diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index 67133e120..4d96b4000 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -4,8 +4,6 @@ package keys_test import ( - "crypto/rand" - "crypto/rsa" "encoding/json" "fmt" "io" @@ -21,11 +19,8 @@ import ( "github.com/absmach/supermq/auth/mocks" smqlog "github.com/absmach/supermq/logger" svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" ) const ( @@ -323,17 +318,17 @@ func TestRetrieveJWKS(t *testing.T) { cases := []struct { desc string - svcRes []auth.JWK + svcRes []auth.PublicKeyInfo status int }{ { desc: "retrieve JWKS with keys", - svcRes: []auth.JWK{newJWK(t), newJWK(t)}, + svcRes: []auth.PublicKeyInfo{newPublicKeyInfo(), newPublicKeyInfo()}, status: http.StatusOK, }, { desc: "retrieve empty JWKS", - svcRes: []auth.JWK{}, + svcRes: []auth.PublicKeyInfo{}, status: http.StatusOK, }, } @@ -352,14 +347,13 @@ func TestRetrieveJWKS(t *testing.T) { } } -func newJWK(t *testing.T) auth.JWK { - pKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err)) - jwkKey, err := jwk.FromRaw(&pKey.PublicKey) - require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err)) - err = jwkKey.Set(jwk.KeyIDKey, "test-key-id") - require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err)) - err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String()) - require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err)) - return auth.NewJWK(jwkKey) +func newPublicKeyInfo() auth.PublicKeyInfo { + return auth.PublicKeyInfo{ + KeyID: "test-key-id", + KeyType: "OKP", + Algorithm: "EdDSA", + Use: "sig", + Curve: "Ed25519", + X: "base64url-encoded-public-key", + } } diff --git a/auth/api/http/keys/responses.go b/auth/api/http/keys/responses.go index 2e802f526..b0d926a73 100644 --- a/auth/api/http/keys/responses.go +++ b/auth/api/http/keys/responses.go @@ -11,7 +11,6 @@ import ( "github.com/absmach/supermq" "github.com/absmach/supermq/auth" - "github.com/lestrrat-go/jwx/v2/jwk" ) var ( @@ -76,20 +75,17 @@ func (res revokeKeyRes) Empty() bool { } type retrieveJWKSRes struct { - Keys []auth.JWK `json:"-"` - CacheMaxAge int `json:"-"` - CacheStaleWhileRevalidate int `json:"-"` + Keys []auth.PublicKeyInfo `json:"-"` + CacheMaxAge int `json:"-"` + CacheStaleWhileRevalidate int `json:"-"` } func (res retrieveJWKSRes) MarshalJSON() ([]byte, error) { - set := jwk.NewSet() - for _, k := range res.Keys { - if err := set.AddKey(k.Key()); err != nil { - return nil, err - } + type jwksResponse struct { + Keys []auth.PublicKeyInfo `json:"keys"` } - return json.Marshal(set) + return json.Marshal(jwksResponse{Keys: res.Keys}) } func (res retrieveJWKSRes) Code() int { diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go deleted file mode 100644 index a19b34906..000000000 --- a/auth/jwt/token_test.go +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package jwt_test - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "fmt" - "testing" - "time" - - "github.com/absmach/supermq/auth" - authjwt "github.com/absmach/supermq/auth/jwt" - "github.com/absmach/supermq/auth/mocks" - "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const ( - tokenType = "type" - roleField = "role" - VerifiedField = "verified" - issuerName = "supermq.auth" -) - -var ( - errJWTExpiryKey = errors.New(`"exp" not satisfied`) - keyManager = new(mocks.KeyManager) -) - -func TestIssue(t *testing.T) { - tokenizer := authjwt.New(keyManager) - - validKey := key() - signedToken, _, err := signToken(issuerName, validKey, false) - require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) - - cases := []struct { - desc string - key auth.Key - managerReq jwt.Token - managerResp []byte - managerErr error - err error - }{ - { - desc: "issue new token", - key: validKey, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token with OAuth token", - key: auth.Key{ - ID: testsutil.GenerateUUID(t), - Type: auth.AccessKey, - Subject: testsutil.GenerateUUID(t), - IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second), - ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second), - }, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token without a domain", - key: auth.Key{ - ID: testsutil.GenerateUUID(t), - Type: auth.AccessKey, - Subject: testsutil.GenerateUUID(t), - IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second), - }, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token without a subject", - key: auth.Key{ - ID: testsutil.GenerateUUID(t), - Type: auth.AccessKey, - Subject: "", - IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second), - }, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token without type", - key: auth.Key{ - ID: testsutil.GenerateUUID(t), - Type: auth.KeyType(auth.InvitationKey + 1), - Subject: testsutil.GenerateUUID(t), - IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second), - }, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token without a domain and subject", - key: auth.Key{ - ID: testsutil.GenerateUUID(t), - Type: auth.AccessKey, - Subject: "", - IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second), - ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second), - }, - managerResp: []byte(signedToken), - err: nil, - }, - { - desc: "issue token with failed to sign jwt", - key: validKey, - managerErr: svcerr.ErrAuthentication, - err: authjwt.ErrSignJWT, - }, - } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - tc.managerReq = newToken(issuerName, tc.key) - kmCall := keyManager.On("SignJWT", tc.managerReq).Return(tc.managerResp, tc.managerErr) - tkn, err := tokenizer.Issue(tc.key) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) - if err == nil { - assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc)) - } - kmCall.Unset() - }) - } -} - -func TestParse(t *testing.T) { - tokenizer := authjwt.New(keyManager) - - validKey := key() - signedTkn, parsedTkn, err := signToken(issuerName, validKey, true) - require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) - - apiKey := key() - apiKey.Type = auth.APIKey - apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) - apiToken, _, err := signToken(issuerName, apiKey, false) - require.Nil(t, err, fmt.Sprintf("issuing api key expected to succeed: %s", err)) - - expKey := key() - expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) - expToken, _, err := signToken(issuerName, expKey, false) - require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) - - emptySubjectKey := key() - emptySubjectKey.Subject = "" - signedEmptySubjectTkn, parsedEmptySubjectTkn, err := signToken(issuerName, emptySubjectKey, true) - require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) - - emptyTypeKey := key() - emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1) - emptyTypeToken, _, err := signToken(issuerName, emptyTypeKey, false) - require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) - - emptyKey := key() - emptyKey.Subject = "" - - signedInValidTkn, parsedInvalidTkn, err := signToken("invalid.issuer", key(), true) - require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) - - cases := []struct { - desc string - key auth.Key - token string - managerRes jwt.Token - managerErr error - err error - }{ - { - desc: "parse valid key", - key: validKey, - token: signedTkn, - managerRes: parsedTkn, - err: nil, - }, - { - desc: "parse invalid key", - key: auth.Key{}, - token: "invalid", - managerErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthentication, - }, - { - desc: "parse expired key", - key: auth.Key{}, - token: expToken, - managerErr: errJWTExpiryKey, - err: auth.ErrExpiry, - }, - { - desc: "parse expired API key", - key: apiKey, - token: apiToken, - managerErr: errJWTExpiryKey, - err: auth.ErrExpiry, - }, - { - desc: "parse token with invalid issuer", - key: auth.Key{}, - token: signedInValidTkn, - managerRes: parsedInvalidTkn, - err: svcerr.ErrAuthentication, - }, - { - desc: "parse token with empty subject", - key: emptySubjectKey, - token: signedEmptySubjectTkn, - managerRes: parsedEmptySubjectTkn, - err: nil, - }, - { - desc: "parse token with empty type", - key: emptyTypeKey, - token: emptyTypeToken, - managerRes: newToken(issuerName, emptyKey), - err: svcerr.ErrAuthentication, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - kmCall := keyManager.On("ParseJWT", tc.token).Return(tc.managerRes, tc.managerErr) - key, err := tokenizer.Parse(context.Background(), tc.token) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) - if err == nil { - assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) - } - kmCall.Unset() - }) - } -} - -func TestRetrieveJWKS(t *testing.T) { - tokenizer := authjwt.New(keyManager) - - cases := []struct { - desc string - keys []auth.JWK - retrieveErr error - err error - }{ - { - desc: "retrieve jwks with keys", - keys: []auth.JWK{newJWK(t), newJWK(t)}, - }, - { - desc: "retrieve empty jwks", - keys: []auth.JWK{}, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - kmCall := keyManager.On("PublicJWKS", mock.Anything).Return(tc.keys, tc.retrieveErr) - jwks := tokenizer.RetrieveJWKS() - assert.Equal(t, tc.keys, jwks, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.keys, jwks)) - kmCall.Unset() - }) - } -} - -func key() auth.Key { - exp := time.Now().UTC().Add(10 * time.Minute).Round(time.Second) - return auth.Key{ - ID: "66af4a67-3823-438a-abd7-efdb613eaef6", - Type: auth.AccessKey, - Issuer: "supermq.auth", - Role: auth.UserRole, - Subject: "66af4a67-3823-438a-abd7-efdb613eaef6", - IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second), - ExpiresAt: exp, - } -} - -func newToken(issuerName string, key auth.Key) jwt.Token { - builder := jwt.NewBuilder() - builder. - Issuer(issuerName). - IssuedAt(key.IssuedAt). - Claim(tokenType, key.Type). - Expiration(key.ExpiresAt) - builder.Claim(roleField, key.Role) - builder.Claim(VerifiedField, key.Verified) - if key.Subject != "" { - builder.Subject(key.Subject) - } - if key.ID != "" { - builder.JwtID(key.ID) - } - tkn, _ := builder.Build() - return tkn -} - -func signToken(issuerName string, key auth.Key, parseToken bool) (string, jwt.Token, error) { - tkn := newToken(issuerName, key) - pKey, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - return "", nil, err - } - pubKey := &pKey.PublicKey - sTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.RS256, pKey)) - if err != nil { - return "", nil, err - } - if !parseToken { - return string(sTkn), nil, nil - } - pTkn, err := jwt.Parse( - sTkn, - jwt.WithValidate(true), - jwt.WithKey(jwa.RS256, pubKey), - ) - if err != nil { - return "", nil, err - } - return string(sTkn), pTkn, nil -} - -func newJWK(t *testing.T) auth.JWK { - pKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err)) - jwkKey, err := jwk.FromRaw(&pKey.PublicKey) - require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err)) - err = jwkKey.Set(jwk.KeyIDKey, "test-key-id") - require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err)) - err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String()) - require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err)) - return auth.NewJWK(jwkKey) -} diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go deleted file mode 100644 index 578fed929..000000000 --- a/auth/jwt/tokenizer.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package jwt - -import ( - "context" - "encoding/json" - - "github.com/absmach/supermq/auth" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -var ( - // errInvalidIssuer is returned when the issuer is not supermq.auth. - errInvalidIssuer = errors.New("invalid token issuer value") - // errInvalidType is returned when there is no type field. - errInvalidType = errors.New("invalid token type") - // errInvalidRole is returned when the role is invalid. - errInvalidRole = errors.New("invalid role") - // errInvalidVerified is returned when the verified is invalid. - errInvalidVerified = errors.New("invalid verified") - // errJWTExpiryKey is used to check if the token is expired. - errJWTExpiryKey = errors.New(`"exp" not satisfied`) - // ErrSignJWT indicates an error in signing jwt token. - ErrSignJWT = errors.New("failed to sign jwt token") - // ErrValidateJWTToken indicates a failure to validate JWT token. - ErrValidateJWTToken = errors.New("failed to validate jwt token") - // ErrJSONHandle indicates an error in handling JSON. - ErrJSONHandle = errors.New("failed to perform operation JSON") -) - -const ( - issuerName = "supermq.auth" - tokenType = "type" - RoleField = "role" - VerifiedField = "verified" - patPrefix = "pat" -) - -type tokenizer struct { - keyManager auth.KeyManager -} - -var _ auth.Tokenizer = (*tokenizer)(nil) - -// New instantiates an implementation of Tokenizer service. -func New(keyManager auth.KeyManager) auth.Tokenizer { - return &tokenizer{ - keyManager: keyManager, - } -} - -func (tok *tokenizer) Issue(key auth.Key) (string, error) { - builder := jwt.NewBuilder() - builder. - Issuer(issuerName). - IssuedAt(key.IssuedAt). - Claim(tokenType, key.Type). - Expiration(key.ExpiresAt) - builder.Claim(RoleField, key.Role) - builder.Claim(VerifiedField, key.Verified) - if key.Subject != "" { - builder.Subject(key.Subject) - } - if key.ID != "" { - builder.JwtID(key.ID) - } - tkn, err := builder.Build() - if err != nil { - return "", errors.Wrap(svcerr.ErrAuthentication, err) - } - signedTkn, err := tok.keyManager.SignJWT(tkn) - if err != nil { - return "", errors.Wrap(ErrSignJWT, err) - } - return string(signedTkn), nil -} - -func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { - if len(token) >= 3 && token[:3] == patPrefix { - return auth.Key{Type: auth.PersonalAccessToken}, nil - } - - tkn, err := tok.validateToken(token) - if err != nil { - return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) - } - - key, err := ToKey(tkn) - if err != nil { - return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) - } - - return key, nil -} - -func (tok *tokenizer) validateToken(token string) (jwt.Token, error) { - tkn, err := tok.keyManager.ParseJWT(token) - if err != nil { - if errors.Contains(err, errJWTExpiryKey) { - return nil, auth.ErrExpiry - } - - return nil, err - } - validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError { - if t.Issuer() != issuerName { - return jwt.NewValidationError(errInvalidIssuer) - } - return nil - }) - if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil { - return nil, errors.Wrap(ErrValidateJWTToken, err) - } - - return tkn, nil -} - -func (tok *tokenizer) RetrieveJWKS() []auth.JWK { - return tok.keyManager.PublicJWKS() -} - -func ToKey(tkn jwt.Token) (auth.Key, error) { - data, err := json.Marshal(tkn.PrivateClaims()) - if err != nil { - return auth.Key{}, errors.Wrap(ErrJSONHandle, err) - } - var key auth.Key - if err := json.Unmarshal(data, &key); err != nil { - return auth.Key{}, errors.Wrap(ErrJSONHandle, err) - } - - tType, ok := tkn.Get(tokenType) - if !ok { - return auth.Key{}, errInvalidType - } - kType, ok := tType.(float64) - if !ok { - return auth.Key{}, errInvalidType - } - kt := auth.KeyType(kType) - if !kt.Validate() { - return auth.Key{}, errInvalidType - } - - tRole, ok := tkn.Get(RoleField) - if !ok { - return auth.Key{}, errInvalidRole - } - kRole, ok := tRole.(float64) - if !ok { - return auth.Key{}, errInvalidRole - } - - tVerified, ok := tkn.Get(VerifiedField) - if !ok { - return auth.Key{}, errInvalidVerified - } - kVerified, ok := tVerified.(bool) - if !ok { - return auth.Key{}, errInvalidVerified - } - - kr := auth.Role(kRole) - if !kr.Validate() { - return auth.Key{}, errInvalidRole - } - - key.ID = tkn.JwtID() - key.Type = auth.KeyType(kType) - key.Role = auth.Role(kRole) - key.Issuer = tkn.Issuer() - key.Subject = tkn.Subject() - key.IssuedAt = tkn.IssuedAt() - key.ExpiresAt = tkn.Expiration() - key.Verified = kVerified - - return key, nil -} diff --git a/auth/key_manager.go b/auth/key_manager.go index 9806582ac..73c74db82 100644 --- a/auth/key_manager.go +++ b/auth/key_manager.go @@ -4,47 +4,57 @@ package auth import ( + "context" "errors" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" ) var ( ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm") ErrInvalidSymmetricKey = errors.New("invalid symmetric key") + ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm") ) -// JWK represents a JSON Web Key. -type JWK struct { - key jwk.Key +// PublicKeyInfo represents a public key for external distribution via JWKS. +// This follows RFC 7517 (JSON Web Key) specification. +type PublicKeyInfo struct { + KeyID string `json:"kid"` + KeyType string `json:"kty"` + Algorithm string `json:"alg"` + Use string `json:"use,omitempty"` + + // EdDSA (Ed25519) fields + Curve string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + + // Future: RSA fields (n, e), ECDSA fields (x, y, crv), etc. } -// NewJWK creates a new JWK from a jwk.Key. -func NewJWK(key jwk.Key) JWK { - return JWK{key: key} -} - -// Key returns the underlying jwk.Key. -func (j JWK) Key() jwk.Key { - return j.key -} - -// KeyManager represents a manager for JWT keys. -type KeyManager interface { - SignJWT(token jwt.Token) ([]byte, error) - - ParseJWT(token string) (jwt.Token, error) - - PublicJWKS() []JWK +// Tokenizer handles token creation and verification for authentication. +// Implementations manage underlying cryptographic operations and key distribution. +type Tokenizer interface { + // Issue creates a signed token string from the given key claims. + Issue(key Key) (token string, err error) + + // Parse verifies and parses a token string (JWT or PAT), returning the extracted claims. + // For PAT tokens (prefix "pat"), returns a Key with Type set to PersonalAccessToken. + // For JWT tokens, performs cryptographic verification and returns the parsed claims. + Parse(ctx context.Context, token string) (key Key, err error) + + // RetrieveJWKS returns public keys for distribution via JWKS endpoint. + // Returns ErrPublicKeysNotSupported for symmetric tokenizers (HMAC). + RetrieveJWKS() ([]PublicKeyInfo, error) } +// IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based). +// Returns true for HMAC algorithms (HS256, HS384, HS512). +// Returns false for asymmetric algorithms (EdDSA). +// Returns error for unsupported algorithms. func IsSymmetricAlgorithm(alg string) (bool, error) { switch alg { - case "HS256", "HS384", "HS512": - return true, nil case "EdDSA": return false, nil + case "HS256", "HS384", "HS512": + return true, nil default: return false, ErrUnsupportedKeyAlgorithm } diff --git a/auth/keymanager/asymmetric/key_manager.go b/auth/keymanager/asymmetric/key_manager.go deleted file mode 100644 index 7eeba6134..000000000 --- a/auth/keymanager/asymmetric/key_manager.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package asymmetric - -import ( - "crypto/ed25519" - "crypto/x509" - "encoding/pem" - "errors" - "os" - - "github.com/absmach/supermq" - "github.com/absmach/supermq/auth" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jws" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -var ( - errLoadingPrivateKey = errors.New("failed to load private key") - errInvalidKeySize = errors.New("invalid ED25519 key size") - errParsingPrivateKey = errors.New("failed to parse private key") - errInvalidKeyType = errors.New("private key is not ED25519") - errGeneratingKID = errors.New("failed to generate key ID") -) - -type manager struct { - privateKey jwk.Key - publicKey jwk.Key - kid string -} - -var _ auth.KeyManager = (*manager)(nil) - -// NewKeyManager creates a new asymmetric key manager that loads the private key from a file. -func NewKeyManager(privateKeyPath string, idProvider supermq.IDProvider) (auth.KeyManager, error) { - kid, err := idProvider.ID() - if err != nil { - return nil, errors.Join(errGeneratingKID, err) - } - - privateJwk, publicJwk, err := loadKeyPair(privateKeyPath, kid) - if err != nil { - return nil, err - } - - return &manager{ - privateKey: privateJwk, - publicKey: publicJwk, - kid: kid, - }, nil -} - -func (km *manager) SignJWT(token jwt.Token) ([]byte, error) { - return jwt.Sign(token, jwt.WithKey(jwa.EdDSA, km.privateKey)) -} - -func (km *manager) ParseJWT(token string) (jwt.Token, error) { - set := jwk.NewSet() - if err := set.AddKey(km.publicKey); err != nil { - return nil, err - } - - tkn, err := jwt.Parse( - []byte(token), - jwt.WithValidate(true), - jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)), - ) - if err != nil { - return nil, err - } - return tkn, nil -} - -func (km *manager) PublicJWKS() []auth.JWK { - return []auth.JWK{auth.NewJWK(km.publicKey)} -} - -func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) { - privateKeyBytes, err := os.ReadFile(privateKeyPath) - if err != nil { - return nil, nil, errors.Join(errLoadingPrivateKey, err) - } - - var privateKey ed25519.PrivateKey - block, _ := pem.Decode(privateKeyBytes) - switch { - case block != nil: - parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - return nil, nil, errors.Join(errParsingPrivateKey, err) - } - var ok bool - privateKey, ok = parsedKey.(ed25519.PrivateKey) - if !ok { - return nil, nil, errInvalidKeyType - } - default: - if len(privateKeyBytes) != ed25519.PrivateKeySize { - return nil, nil, errInvalidKeySize - } - privateKey = ed25519.PrivateKey(privateKeyBytes) - } - - publicKey := privateKey.Public().(ed25519.PublicKey) - - privateJwk, err := jwk.FromRaw(privateKey) - if err != nil { - return nil, nil, err - } - if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { - return nil, nil, err - } - if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil { - return nil, nil, err - } - - publicJwk, err := jwk.FromRaw(publicKey) - if err != nil { - return nil, nil, err - } - if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { - return nil, nil, err - } - if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil { - return nil, nil, err - } - - return privateJwk, publicJwk, nil -} diff --git a/auth/keymanager/symmetric/key_manager.go b/auth/keymanager/symmetric/key_manager.go deleted file mode 100644 index facaea253..000000000 --- a/auth/keymanager/symmetric/key_manager.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package symmetric - -import ( - "github.com/absmach/supermq/auth" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwt" -) - -type manager struct { - algorithm jwa.KeyAlgorithm - secret []byte -} - -var _ auth.KeyManager = (*manager)(nil) - -func NewKeyManager(algorithm string, secret []byte) (auth.KeyManager, error) { - alg := jwa.KeyAlgorithmFrom(algorithm) - if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok { - return nil, auth.ErrUnsupportedKeyAlgorithm - } - if len(secret) == 0 { - return nil, auth.ErrInvalidSymmetricKey - } - return &manager{ - secret: secret, - algorithm: alg, - }, nil -} - -func (km *manager) SignJWT(token jwt.Token) ([]byte, error) { - return jwt.Sign(token, jwt.WithKey(km.algorithm, km.secret)) -} - -func (km *manager) ParseJWT(token string) (jwt.Token, error) { - return jwt.Parse( - []byte(token), - jwt.WithValidate(true), - jwt.WithKey(km.algorithm, km.secret), - ) -} - -func (km *manager) PublicJWKS() []auth.JWK { - return nil -} diff --git a/auth/middleware/logging.go b/auth/middleware/logging.go index f4a8128d4..b2b8df685 100644 --- a/auth/middleware/logging.go +++ b/auth/middleware/logging.go @@ -100,7 +100,7 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id aut return lm.svc.Identify(ctx, token) } -func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.JWK) { +func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), diff --git a/auth/middleware/metrics.go b/auth/middleware/metrics.go index c7aed19e6..e8c2560bc 100644 --- a/auth/middleware/metrics.go +++ b/auth/middleware/metrics.go @@ -67,7 +67,7 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.K return ms.svc.Identify(ctx, token) } -func (ms *metricsMiddleware) RetrieveJWKS() []auth.JWK { +func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo { defer func(begin time.Time) { ms.counter.With("method", "retrieve_jwks").Add(1) ms.latency.With("method", "retrieve_jwks").Observe(time.Since(begin).Seconds()) diff --git a/auth/middleware/tracing.go b/auth/middleware/tracing.go index 311a6ef3a..3855b1b1f 100644 --- a/auth/middleware/tracing.go +++ b/auth/middleware/tracing.go @@ -61,7 +61,7 @@ func (tm *tracingMiddleware) Identify(ctx context.Context, token string) (auth.K return tm.svc.Identify(ctx, token) } -func (tm *tracingMiddleware) RetrieveJWKS() []auth.JWK { +func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo { return tm.svc.RetrieveJWKS() } diff --git a/auth/mocks/key_manager.go b/auth/mocks/key_manager.go deleted file mode 100644 index c858f412b..000000000 --- a/auth/mocks/key_manager.go +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright (c) Abstract Machines - -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mocks - -import ( - "github.com/absmach/supermq/auth" - "github.com/lestrrat-go/jwx/v2/jwt" - mock "github.com/stretchr/testify/mock" -) - -// NewKeyManager creates a new instance of KeyManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewKeyManager(t interface { - mock.TestingT - Cleanup(func()) -}) *KeyManager { - mock := &KeyManager{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// KeyManager is an autogenerated mock type for the KeyManager type -type KeyManager struct { - mock.Mock -} - -type KeyManager_Expecter struct { - mock *mock.Mock -} - -func (_m *KeyManager) EXPECT() *KeyManager_Expecter { - return &KeyManager_Expecter{mock: &_m.Mock} -} - -// ParseJWT provides a mock function for the type KeyManager -func (_mock *KeyManager) ParseJWT(token string) (jwt.Token, error) { - ret := _mock.Called(token) - - if len(ret) == 0 { - panic("no return value specified for ParseJWT") - } - - var r0 jwt.Token - var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (jwt.Token, error)); ok { - return returnFunc(token) - } - if returnFunc, ok := ret.Get(0).(func(string) jwt.Token); ok { - r0 = returnFunc(token) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(jwt.Token) - } - } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(token) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// KeyManager_ParseJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJWT' -type KeyManager_ParseJWT_Call struct { - *mock.Call -} - -// ParseJWT is a helper method to define mock.On call -// - token string -func (_e *KeyManager_Expecter) ParseJWT(token interface{}) *KeyManager_ParseJWT_Call { - return &KeyManager_ParseJWT_Call{Call: _e.mock.On("ParseJWT", token)} -} - -func (_c *KeyManager_ParseJWT_Call) Run(run func(token string)) *KeyManager_ParseJWT_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 string - if args[0] != nil { - arg0 = args[0].(string) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *KeyManager_ParseJWT_Call) Return(token1 jwt.Token, err error) *KeyManager_ParseJWT_Call { - _c.Call.Return(token1, err) - return _c -} - -func (_c *KeyManager_ParseJWT_Call) RunAndReturn(run func(token string) (jwt.Token, error)) *KeyManager_ParseJWT_Call { - _c.Call.Return(run) - return _c -} - -// PublicJWKS provides a mock function for the type KeyManager -func (_mock *KeyManager) PublicJWKS() []auth.JWK { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for PublicJWKS") - } - - var r0 []auth.JWK - if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]auth.JWK) - } - } - return r0 -} - -// KeyManager_PublicJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublicJWKS' -type KeyManager_PublicJWKS_Call struct { - *mock.Call -} - -// PublicJWKS is a helper method to define mock.On call -func (_e *KeyManager_Expecter) PublicJWKS() *KeyManager_PublicJWKS_Call { - return &KeyManager_PublicJWKS_Call{Call: _e.mock.On("PublicJWKS")} -} - -func (_c *KeyManager_PublicJWKS_Call) Run(run func()) *KeyManager_PublicJWKS_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *KeyManager_PublicJWKS_Call) Return(jWKs []auth.JWK) *KeyManager_PublicJWKS_Call { - _c.Call.Return(jWKs) - return _c -} - -func (_c *KeyManager_PublicJWKS_Call) RunAndReturn(run func() []auth.JWK) *KeyManager_PublicJWKS_Call { - _c.Call.Return(run) - return _c -} - -// SignJWT provides a mock function for the type KeyManager -func (_mock *KeyManager) SignJWT(token jwt.Token) ([]byte, error) { - ret := _mock.Called(token) - - if len(ret) == 0 { - panic("no return value specified for SignJWT") - } - - var r0 []byte - var r1 error - if returnFunc, ok := ret.Get(0).(func(jwt.Token) ([]byte, error)); ok { - return returnFunc(token) - } - if returnFunc, ok := ret.Get(0).(func(jwt.Token) []byte); ok { - r0 = returnFunc(token) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - if returnFunc, ok := ret.Get(1).(func(jwt.Token) error); ok { - r1 = returnFunc(token) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// KeyManager_SignJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignJWT' -type KeyManager_SignJWT_Call struct { - *mock.Call -} - -// SignJWT is a helper method to define mock.On call -// - token jwt.Token -func (_e *KeyManager_Expecter) SignJWT(token interface{}) *KeyManager_SignJWT_Call { - return &KeyManager_SignJWT_Call{Call: _e.mock.On("SignJWT", token)} -} - -func (_c *KeyManager_SignJWT_Call) Run(run func(token jwt.Token)) *KeyManager_SignJWT_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 jwt.Token - if args[0] != nil { - arg0 = args[0].(jwt.Token) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *KeyManager_SignJWT_Call) Return(bytes []byte, err error) *KeyManager_SignJWT_Call { - _c.Call.Return(bytes, err) - return _c -} - -func (_c *KeyManager_SignJWT_Call) RunAndReturn(run func(token jwt.Token) ([]byte, error)) *KeyManager_SignJWT_Call { - _c.Call.Return(run) - return _c -} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 11fd23708..4482b6df7 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -1029,19 +1029,19 @@ func (_c *Service_ResetPATSecret_Call) RunAndReturn(run func(ctx context.Context } // RetrieveJWKS provides a mock function for the type Service -func (_mock *Service) RetrieveJWKS() []auth.JWK { +func (_mock *Service) RetrieveJWKS() []auth.PublicKeyInfo { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for RetrieveJWKS") } - var r0 []auth.JWK - if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok { + var r0 []auth.PublicKeyInfo + if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]auth.JWK) + r0 = ret.Get(0).([]auth.PublicKeyInfo) } } return r0 @@ -1064,12 +1064,12 @@ func (_c *Service_RetrieveJWKS_Call) Run(run func()) *Service_RetrieveJWKS_Call return _c } -func (_c *Service_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Service_RetrieveJWKS_Call { - _c.Call.Return(jWKs) +func (_c *Service_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call { + _c.Call.Return(publicKeyInfos) return _c } -func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Service_RetrieveJWKS_Call { +func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call { _c.Call.Return(run) return _c } diff --git a/auth/mocks/tokenizer.go b/auth/mocks/tokenizer.go index ce5a09057..f4b5c786f 100644 --- a/auth/mocks/tokenizer.go +++ b/auth/mocks/tokenizer.go @@ -169,22 +169,31 @@ func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token } // RetrieveJWKS provides a mock function for the type Tokenizer -func (_mock *Tokenizer) RetrieveJWKS() []auth.JWK { +func (_mock *Tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for RetrieveJWKS") } - var r0 []auth.JWK - if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok { + var r0 []auth.PublicKeyInfo + var r1 error + if returnFunc, ok := ret.Get(0).(func() ([]auth.PublicKeyInfo, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]auth.JWK) + r0 = ret.Get(0).([]auth.PublicKeyInfo) } } - return r0 + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 } // Tokenizer_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS' @@ -204,12 +213,12 @@ func (_c *Tokenizer_RetrieveJWKS_Call) Run(run func()) *Tokenizer_RetrieveJWKS_C return _c } -func (_c *Tokenizer_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Tokenizer_RetrieveJWKS_Call { - _c.Call.Return(jWKs) +func (_c *Tokenizer_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo, err error) *Tokenizer_RetrieveJWKS_Call { + _c.Call.Return(publicKeyInfos, err) return _c } -func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Tokenizer_RetrieveJWKS_Call { +func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() ([]auth.PublicKeyInfo, error)) *Tokenizer_RetrieveJWKS_Call { _c.Call.Return(run) return _c } diff --git a/auth/pat.go b/auth/pat.go index 3917806ee..dcc8dc21d 100644 --- a/auth/pat.go +++ b/auth/pat.go @@ -186,6 +186,8 @@ func ParseEntityType(et string) (EntityType, error) { return UsersType, nil case DashboardsStr: return DashboardType, nil + case MessagesStr: + return MessagesType, nil default: return 0, fmt.Errorf("unknown domain entity type %s", et) } diff --git a/auth/service.go b/auth/service.go index a8323df29..b88a8efaa 100644 --- a/auth/service.go +++ b/auth/service.go @@ -12,6 +12,7 @@ import ( "github.com/absmach/supermq" "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" "github.com/google/uuid" @@ -84,8 +85,8 @@ type Authn interface { // other reason, non-nil error value is returned in response. Identify(ctx context.Context, token string) (Key, error) - // RetrieveJWKS retrieves a JWKs to validate issued tokens. - RetrieveJWKS() []JWK + // RetrieveJWKS retrieves public keys to validate issued tokens. + RetrieveJWKS() []PublicKeyInfo } // Service specifies an API that must be fulfilled by the domain service @@ -201,8 +202,12 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) { } } -func (svc service) RetrieveJWKS() []JWK { - return svc.tokenizer.RetrieveJWKS() +func (svc service) RetrieveJWKS() []PublicKeyInfo { + keys, err := svc.tokenizer.RetrieveJWKS() + if err != nil { + return nil + } + return keys } func (svc service) Authorize(ctx context.Context, pr policies.Policy) error { @@ -803,6 +808,9 @@ func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) ( _, err = svc.pats.Retrieve(ctx, key.Subject, patID) if err != nil { + if errors.Contains(err, repoerr.ErrNotFound) { + return Key{}, svcerr.ErrNotFound + } return Key{}, errors.Wrap(svcerr.ErrAuthorization, err) } diff --git a/auth/tokenizer.go b/auth/tokenizer.go deleted file mode 100644 index 5d524a782..000000000 --- a/auth/tokenizer.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package auth - -import ( - "context" -) - -// Tokenizer specifies API for encoding and decoding between string and Key. -type Tokenizer interface { - // Issue converts API Key to its string representation. - Issue(key Key) (token string, err error) - - // Parse extracts API Key data from string token. - Parse(ctx context.Context, token string) (key Key, err error) - - // RetrieveJWKS returns the JSON Web Key Set. - RetrieveJWKS() []JWK -} diff --git a/auth/tokenizer/asymmetric/README.md b/auth/tokenizer/asymmetric/README.md new file mode 100644 index 000000000..d64f96bd9 --- /dev/null +++ b/auth/tokenizer/asymmetric/README.md @@ -0,0 +1,164 @@ +# Asymmetric Tokenizer + +EdDSA (Ed25519) tokenizer with support for zero-downtime key rotation. + +## Features + +- **Single-key mode** - Simple setup with one active key +- **Two-key mode** - Active + retiring keys for _zero-downtime rotation_ +- **JWKS endpoint** - Publishes all valid public keys for token verification + +## Configuration + +The tokenizer uses environment variables to specify key file paths: + +| Environment Variable | Required | Description | +| --------------------------------- | -------- | ------------------------------------------------ | +| `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` | Yes | Path to active private key file | +| `SMQ_AUTH_KEYS_RETIRING_KEY_PATH` | No | Path to retiring private key file (for rotation) | + +Please note that key names are used as **key IDs (kid)**. + +### Single-Key Mode + +Set only the active key path: + +```bash +export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key" +``` + +The tokenizer will: +- Issue new tokens signed with the active key +- Verify tokens using the active key +- Return one public key in JWKS endpoint + +### Two-Key Mode (Key Rotation) + +Set both active and retiring key paths: + +```bash +export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" +export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" +``` + +The tokenizer will: +- Issue new tokens signed with the active key +- Verify tokens using both active and retiring keys +- Return both public keys in JWKS endpoint + +## Key Rotation Process + +Zero-downtime key rotation in 3 simple steps: + +### 1. Generate New Key + +```bash +openssl genpkey -algorithm Ed25519 -out keys/new.key +``` + +### 2. Update Environment & Restart + +Move the current active key to retiring position and set the new key as active: + +```bash +# Before rotation +SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/current.key" +SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # No retiring key + +# During rotation (both keys active for grace period) +SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" +SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/current.key" + +# After rotation (restart service with new config) +docker-compose restart auth +``` + +During the grace period, tokens signed with either key remain valid. + +### 3. Clean Up After Grace Period + +After the grace period expires (typically 7-30 days), remove the retiring key: + +```bash +# Remove retiring key configuration +SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" +SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # Remove retiring key + +# Restart service +docker-compose restart auth + +# Delete old key file +rm keys/current.key +``` + +## Grace Period Recommendations + +**Recommended:** 168 hours (7 days) +**Minimum:** 24 hours +**Maximum:** 720 hours (30 days) + +The grace period should be longer than your longest-lived access token duration. + +## Security Best Practices + +- Store private keys with `0600` permissions +- Use cryptographically secure key generation: + ```bash + openssl genpkey -algorithm Ed25519 -out private.key + chmod 600 private.key + ``` +- Rotate keys regularly: + - Standard environments: every 90 days + - High-security environments: every 30 days +- Never commit keys to version control +- Use secrets management in production (HashiCorp Vault, AWS Secrets Manager, etc.) + +## Example: Complete Rotation + +```bash +# Day 0: Normal operation +export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2024.pem" +export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" + +# Day 1: Start rotation - generate new key +openssl genpkey -algorithm Ed25519 -out ./keys/key-2025.pem +chmod 600 ./keys/key-2025.pem + +# Day 1: Update config and restart +export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2025.pem" +export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/key-2024.pem" +docker-compose restart auth + +# Day 8: Grace period expired - remove old key +export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" +docker-compose restart auth +rm ./keys/key-2024.pem +``` + +## Troubleshooting + +### Active key not found + +``` +Error: active key file not found: ./keys/active.key +``` + +**Solution:** Ensure the file exists and path is correct. Verify `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` environment variable. + +### Retiring key warning + +If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key: + +``` +WARN: failed to load retiring key, continuing without it +``` + +This is by design - a missing retiring key won't prevent startup. + +### Invalid key format + +``` +Error: failed to parse private key +``` + +**Solution:** Ensure you're using Ed25519 keys in PEM format (PKCS8). diff --git a/auth/tokenizer/asymmetric/rotation_test.go b/auth/tokenizer/asymmetric/rotation_test.go new file mode 100644 index 000000000..64547620a --- /dev/null +++ b/auth/tokenizer/asymmetric/rotation_test.go @@ -0,0 +1,160 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package asymmetric_test + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/tokenizer/asymmetric" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type incrementingIDProvider struct { + counter int +} + +func (p *incrementingIDProvider) ID() (string, error) { + p.counter++ + return fmt.Sprintf("key-id-%d", p.counter), nil +} + +func TestTwoKeyRotation(t *testing.T) { + tmpDir := t.TempDir() + + _, activePriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + _, retiringPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + activeKeyPath := filepath.Join(tmpDir, "active.key") + retiringKeyPath := filepath.Join(tmpDir, "retiring.key") + + saveKey(t, activePriv, activeKeyPath) + saveKey(t, retiringPriv, retiringKeyPath) + + idProvider := &incrementingIDProvider{} + tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger()) + require.NoError(t, err) + + testKey := auth.Key{ + ID: "test-key", + Type: auth.AccessKey, + Subject: "user-123", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + } + + token, err := tokenizer.Issue(testKey) + require.NoError(t, err) + assert.NotEmpty(t, token) + + verified, err := tokenizer.Parse(context.Background(), token) + require.NoError(t, err, "Should work with active token") + assert.Equal(t, testKey.Subject, verified.Subject) + + publicKeys, err := tokenizer.RetrieveJWKS() + require.NoError(t, err) + assert.Len(t, publicKeys, 2, "Should return both active and retiring keys") + + keyIDs := make(map[string]bool) + for _, pk := range publicKeys { + keyIDs[pk.KeyID] = true + } + assert.Len(t, keyIDs, 2, "Both keys should have unique IDs") +} + +func TestSingleKeyMode(t *testing.T) { + tmpDir := t.TempDir() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + keyPath := filepath.Join(tmpDir, "single.key") + saveKey(t, privateKey, keyPath) + + idProvider := &mockIDProvider{id: "single-id"} + tokenizer, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + require.NoError(t, err) + + testKey := auth.Key{ + ID: "test", + Type: auth.AccessKey, + Subject: "user", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + } + + token, err := tokenizer.Issue(testKey) + require.NoError(t, err) + + _, err = tokenizer.Parse(context.Background(), token) + require.NoError(t, err) + + publicKeys, err := tokenizer.RetrieveJWKS() + require.NoError(t, err, "Should return one active key") + assert.Len(t, publicKeys, 1, "Should return only the active key") +} + +func TestMissingRetiringKey(t *testing.T) { + tmpDir := t.TempDir() + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + activeKeyPath := filepath.Join(tmpDir, "active.key") + saveKey(t, privateKey, activeKeyPath) + + retiringKeyPath := filepath.Join(tmpDir, "nonexistent.key") + + idProvider := &mockIDProvider{id: "test-id"} + tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger()) + require.NoError(t, err, "Should succeed even if retiring key is missing") + + testKey := auth.Key{ + ID: "test", + Type: auth.AccessKey, + Subject: "user", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + } + + token, err := tokenizer.Issue(testKey) + require.NoError(t, err) + + _, err = tokenizer.Parse(context.Background(), token) + require.NoError(t, err) + + publicKeys, err := tokenizer.RetrieveJWKS() + require.NoError(t, err) + assert.Len(t, publicKeys, 1, "Should return only active key when retiring key is missing") +} + +func saveKey(t *testing.T, privateKey ed25519.PrivateKey, path string) { + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + err = os.WriteFile(path, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) +} diff --git a/auth/tokenizer/asymmetric/tokenizer.go b/auth/tokenizer/asymmetric/tokenizer.go new file mode 100644 index 000000000..ac8d168ec --- /dev/null +++ b/auth/tokenizer/asymmetric/tokenizer.go @@ -0,0 +1,246 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package asymmetric + +import ( + "context" + "crypto/ed25519" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/auth" + smqjwt "github.com/absmach/supermq/auth/tokenizer/util" + "github.com/absmach/supermq/pkg/errors" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +const patPrefix = "pat" + +var ( + errLoadingPrivateKey = errors.New("failed to load private key") + errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID") + errInvalidKeySize = errors.New("invalid ED25519 key size") + errParsingPrivateKey = errors.New("failed to parse private key") + errInvalidKeyType = errors.New("private key is not ED25519") + errNoValidPublicKeys = errors.New("no valid public keys available") + errNoActiveKey = errors.New("active key not loaded") +) + +type keyPair struct { + id string + privateKey jwk.Key + publicKey jwk.Key +} + +// Tokenizer is safe for concurrent use. Keys are set during construction +// and never modified afterward. +type tokenizer struct { + activeKey *keyPair + retiringKey *keyPair // Optional, for key rotation grace period +} + +var _ auth.Tokenizer = (*tokenizer)(nil) + +// NewTokenizer creates a new asymmetric tokenizer with active and optionally retiring keys. +// activeKeyPath is required. retiringKeyPath is optional (can be empty string). +// If retiringKeyPath is provided but the file doesn't exist or is invalid, a warning is logged +// but the tokenizer is still created with just the active key. +// Key IDs are derived from filenames to ensure consistency across multiple service instances. +func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDProvider, logger *slog.Logger) (auth.Tokenizer, error) { + activeKID := keyIDFromPath(activeKeyPath) + + activePrivateJwk, activePublicJwk, err := loadKeyPair(activeKeyPath, activeKID) + if err != nil { + return nil, err + } + + mgr := &tokenizer{ + activeKey: &keyPair{ + id: activeKID, + privateKey: activePrivateJwk, + publicKey: activePublicJwk, + }, + } + + if retiringKeyPath != "" { + retiringKID := keyIDFromPath(retiringKeyPath) + if retiringKID == activeKID { + return nil, errDuplicateRetiringKeyID + } + + retiringPrivateJwk, retiringPublicJwk, err := loadKeyPair(retiringKeyPath, retiringKID) + if err != nil { + logger.Warn("failed to load retiring key, continuing without it", slog.Any("error", err)) + return mgr, nil + } + + mgr.retiringKey = &keyPair{ + id: retiringKID, + privateKey: retiringPrivateJwk, + publicKey: retiringPublicJwk, + } + logger.Info("loaded retiring key for rotation grace period", slog.String("key_id", retiringKID)) + } + + return mgr, nil +} + +func (km *tokenizer) Issue(key auth.Key) (string, error) { + if km.activeKey == nil { + return "", errNoActiveKey + } + + tkn, err := smqjwt.BuildToken(key) + if err != nil { + return "", err + } + headers := jws.NewHeaders() + if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil { + return "", err + } + + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers))) + if err != nil { + return "", err + } + + return string(signedBytes), nil +} + +func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + if len(tokenString) >= 3 && tokenString[:3] == patPrefix { + return auth.Key{Type: auth.PersonalAccessToken}, nil + } + + set := jwk.NewSet() + if err := set.AddKey(km.activeKey.publicKey); err != nil { + return auth.Key{}, err + } + if km.retiringKey != nil { + if err := set.AddKey(km.retiringKey.publicKey); err != nil { + return auth.Key{}, err + } + } + + tkn, err := jwt.Parse( + []byte(tokenString), + jwt.WithValidate(true), + jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)), + ) + if err != nil { + return auth.Key{}, err + } + + if tkn.Issuer() != smqjwt.IssuerName { + return auth.Key{}, smqjwt.ErrInvalidIssuer + } + + return smqjwt.ToKey(tkn) +} + +func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { + publicKeys := make([]auth.PublicKeyInfo, 0, 2) + + if km.activeKey != nil { + if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil { + publicKeys = append(publicKeys, *pkInfo) + } + } + + if km.retiringKey != nil { + if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil { + publicKeys = append(publicKeys, *pkInfo) + } + } + + if len(publicKeys) == 0 { + return nil, errNoValidPublicKeys + } + + return publicKeys, nil +} + +func extractPublicKeyInfo(kp *keyPair) *auth.PublicKeyInfo { + var rawKey ed25519.PublicKey + if err := kp.publicKey.Raw(&rawKey); err != nil { + return nil + } + + return &auth.PublicKeyInfo{ + KeyID: kp.id, + KeyType: "OKP", + Algorithm: "EdDSA", + Use: "sig", + Curve: "Ed25519", + X: base64.RawURLEncoding.EncodeToString(rawKey), + } +} + +func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) { + privateKeyBytes, err := os.ReadFile(privateKeyPath) + if err != nil { + return nil, nil, errors.Wrap(errLoadingPrivateKey, err) + } + + var privateKey ed25519.PrivateKey + block, _ := pem.Decode(privateKeyBytes) + switch { + case block != nil: + parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, nil, errors.Wrap(errParsingPrivateKey, err) + } + var ok bool + privateKey, ok = parsedKey.(ed25519.PrivateKey) + if !ok { + return nil, nil, errInvalidKeyType + } + default: + if len(privateKeyBytes) != ed25519.PrivateKeySize { + return nil, nil, errInvalidKeySize + } + privateKey = ed25519.PrivateKey(privateKeyBytes) + } + + publicKey := privateKey.Public().(ed25519.PublicKey) + + privateJwk, err := jwk.FromRaw(privateKey) + if err != nil { + return nil, nil, err + } + if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { + return nil, nil, err + } + if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil { + return nil, nil, err + } + + publicJwk, err := jwk.FromRaw(publicKey) + if err != nil { + return nil, nil, err + } + if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil { + return nil, nil, err + } + if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil { + return nil, nil, err + } + + return privateJwk, publicJwk, nil +} + +func keyIDFromPath(path string) string { + base := filepath.Base(path) + ext := filepath.Ext(base) + return strings.TrimSuffix(base, ext) +} diff --git a/auth/tokenizer/asymmetric/tokenizer_test.go b/auth/tokenizer/asymmetric/tokenizer_test.go new file mode 100644 index 000000000..0fe726f76 --- /dev/null +++ b/auth/tokenizer/asymmetric/tokenizer_test.go @@ -0,0 +1,418 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package asymmetric_test + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "log/slog" + "os" + "path/filepath" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/tokenizer/asymmetric" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockIDProvider struct { + id string +} + +func (m *mockIDProvider) ID() (string, error) { + return m.id, nil +} + +func newTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) +} + +func TestNewKeyManager(t *testing.T) { + idProvider := &mockIDProvider{id: "unused"} + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "private.key") + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + cases := []struct { + name string + setupKey func() string + expectErr bool + errContains string + }{ + { + name: "valid PEM key", + setupKey: func() string { + err := os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) + return keyPath + }, + expectErr: false, + }, + { + name: "valid raw key", + setupKey: func() string { + rawKeyPath := filepath.Join(tmpDir, "raw_private.key") + err := os.WriteFile(rawKeyPath, privateKey, 0o600) + require.NoError(t, err) + return rawKeyPath + }, + expectErr: false, + }, + { + name: "non-existent key file", + setupKey: func() string { + return filepath.Join(tmpDir, "nonexistent.key") + }, + expectErr: true, + errContains: "failed to load private key", + }, + { + name: "invalid key size", + setupKey: func() string { + invalidPath := filepath.Join(tmpDir, "invalid.key") + err := os.WriteFile(invalidPath, []byte("invalid"), 0o600) + require.NoError(t, err) + return invalidPath + }, + expectErr: true, + errContains: "invalid ED25519 key size", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + path := tc.setupKey() + + km, err := asymmetric.NewTokenizer(path, "", idProvider, newTestLogger()) + + if tc.expectErr { + assert.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + assert.Nil(t, km) + } else { + assert.NoError(t, err) + assert.NotNil(t, km) + } + }) + } +} + +func TestSign(t *testing.T) { + idProvider := &mockIDProvider{id: "unused"} + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "private.key") + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) + + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + require.NoError(t, err) + + cases := []struct { + name string + key auth.Key + }{ + { + name: "sign valid key with all fields", + key: auth.Key{ + ID: "key-id", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + }, + }, + { + name: "sign key without subject", + key: auth.Key{ + ID: "key-id", + Type: auth.APIKey, + Issuer: "supermq.auth", + Role: auth.AdminRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(24 * time.Hour).UTC(), + Verified: false, + }, + }, + { + name: "sign key without ID", + key: auth.Key{ + Type: auth.AccessKey, + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + token, err := km.Issue(tc.key) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + parts := splitJWT(token) + assert.Equal(t, 3, len(parts), "JWT should have 3 parts") + }) + } +} + +func TestVerify(t *testing.T) { + idProvider := &mockIDProvider{id: "unused"} + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "private.key") + kid := "private" + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) + + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + require.NoError(t, err) + + validKey := auth.Key{ + ID: "key-id", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + } + + validToken, err := km.Issue(validKey) + require.NoError(t, err, "Signing a valid token should succeed") + + expiredKey := validKey + expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC() + expiredToken, err := km.Issue(expiredKey) + require.NoError(t, err, "Creating an expired token should succeed") + + wrongIssuerKey := validKey + wrongIssuerKey.Issuer = "wrong.issuer" + + privateJwk, err := jwk.FromRaw(privateKey) + require.NoError(t, err) + require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid)) + + builder := jwt.NewBuilder() + builder.Issuer(wrongIssuerKey.Issuer). + Subject(wrongIssuerKey.Subject). + IssuedAt(wrongIssuerKey.IssuedAt). + Expiration(wrongIssuerKey.ExpiresAt). + JwtID(wrongIssuerKey.ID). + Claim("type", wrongIssuerKey.Type). + Claim("role", wrongIssuerKey.Role). + Claim("verified", wrongIssuerKey.Verified) + + wrongIssuerJWT, err := builder.Build() + require.NoError(t, err) + + wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.EdDSA, privateJwk)) + require.NoError(t, err) + wrongIssuerToken := string(wrongIssuerTokenBytes) + + cases := []struct { + name string + token string + expectErr bool + errContains string + }{ + { + name: "verify valid token", + token: validToken, + expectErr: false, + }, + { + name: "verify expired token", + token: expiredToken, + expectErr: true, + errContains: "exp", + }, + { + name: "verify token with wrong issuer", + token: wrongIssuerToken, + expectErr: true, + errContains: "invalid token issuer", + }, + { + name: "verify malformed token", + token: "not.a.valid.jwt", + expectErr: true, + errContains: "", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + key, err := km.Parse(context.Background(), tc.token) + + if tc.expectErr { + assert.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + } else { + assert.NoError(t, err) + assert.Equal(t, validKey.Subject, key.Subject) + assert.Equal(t, validKey.Type, key.Type) + assert.Equal(t, validKey.Role, key.Role) + } + }) + } +} + +func TestPublicKeys(t *testing.T) { + idProvider := &mockIDProvider{id: "unused"} + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "private.key") + kid := "private" + + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) + + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + require.NoError(t, err) + + keys, err := km.RetrieveJWKS() + assert.NoError(t, err) + assert.Len(t, keys, 1) + + key := keys[0] + assert.Equal(t, kid, key.KeyID) + assert.Equal(t, "OKP", key.KeyType) + assert.Equal(t, "EdDSA", key.Algorithm) + assert.Equal(t, "sig", key.Use) + assert.Equal(t, "Ed25519", key.Curve) + assert.NotEmpty(t, key.X) + + decoded, err := base64.RawURLEncoding.DecodeString(key.X) + assert.NoError(t, err, "The public key should be decoded") + assert.Equal(t, publicKey, ed25519.PublicKey(decoded)) +} + +func TestSignAndVerifyRoundTrip(t *testing.T) { + idProvider := &mockIDProvider{id: "unused"} + + tmpDir := t.TempDir() + keyPath := filepath.Join(tmpDir, "private.key") + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Key, + } + + err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) + require.NoError(t, err) + + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + require.NoError(t, err) + + originalKey := auth.Key{ + ID: "key-123", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-456", + Role: auth.UserRole, + IssuedAt: time.Now().UTC().Truncate(time.Second), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second), + Verified: true, + } + + token, err := km.Issue(originalKey) + require.NoError(t, err) + + verifiedKey, err := km.Parse(context.Background(), token) + require.NoError(t, err, "Verification of a valid key should succeed") + + assert.Equal(t, originalKey.ID, verifiedKey.ID) + assert.Equal(t, originalKey.Type, verifiedKey.Type) + assert.Equal(t, originalKey.Subject, verifiedKey.Subject) + assert.Equal(t, originalKey.Role, verifiedKey.Role) + assert.Equal(t, originalKey.Verified, verifiedKey.Verified) + assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second) + assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second) +} + +func splitJWT(token string) []string { + parts := []string{} + start := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + parts = append(parts, token[start:i]) + start = i + 1 + } + } + parts = append(parts, token[start:]) + return parts +} diff --git a/auth/tokenizer/symmetric/tokenizer.go b/auth/tokenizer/symmetric/tokenizer.go new file mode 100644 index 000000000..5933bb697 --- /dev/null +++ b/auth/tokenizer/symmetric/tokenizer.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package symmetric + +import ( + "context" + + "github.com/absmach/supermq/auth" + smqjwt "github.com/absmach/supermq/auth/tokenizer/util" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +const ( + patPrefix = "pat" +) + +var errJWTExpiryKey = errors.New(`"exp" not satisfied`) + +type tokenizer struct { + algorithm jwa.KeyAlgorithm + secret []byte +} + +var _ auth.Tokenizer = (*tokenizer)(nil) + +func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) { + alg := jwa.KeyAlgorithmFrom(algorithm) + if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok { + return nil, auth.ErrUnsupportedKeyAlgorithm + } + if len(secret) == 0 { + return nil, auth.ErrInvalidSymmetricKey + } + return &tokenizer{ + secret: secret, + algorithm: alg, + }, nil +} + +func (km *tokenizer) Issue(key auth.Key) (string, error) { + tkn, err := smqjwt.BuildToken(key) + if err != nil { + return "", err + } + + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret)) + if err != nil { + return "", err + } + + return string(signedBytes), nil +} + +func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + if len(tokenString) >= 3 && tokenString[:3] == patPrefix { + return auth.Key{Type: auth.PersonalAccessToken}, nil + } + + tkn, err := jwt.Parse( + []byte(tokenString), + jwt.WithValidate(true), + jwt.WithKey(km.algorithm, km.secret), + ) + if err != nil { + if errors.Contains(err, errJWTExpiryKey) { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry) + } + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + + if tkn.Issuer() != smqjwt.IssuerName { + return auth.Key{}, smqjwt.ErrInvalidIssuer + } + + return smqjwt.ToKey(tkn) +} + +func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { + return nil, auth.ErrPublicKeysNotSupported +} diff --git a/auth/tokenizer/symmetric/tokenizer_test.go b/auth/tokenizer/symmetric/tokenizer_test.go new file mode 100644 index 000000000..0703d8d36 --- /dev/null +++ b/auth/tokenizer/symmetric/tokenizer_test.go @@ -0,0 +1,362 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package symmetric_test + +import ( + "context" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/tokenizer/symmetric" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTokenizer(t *testing.T) { + cases := []struct { + name string + algorithm string + secret []byte + expectErr bool + errContains string + }{ + { + name: "valid HS256 algorithm", + algorithm: "HS256", + secret: []byte("my-secret-key-32-bytes-long!!"), + expectErr: false, + }, + { + name: "valid HS384 algorithm", + algorithm: "HS384", + secret: []byte("my-secret-key-48-bytes-long-for-hs384-!!!!!!"), + expectErr: false, + }, + { + name: "valid HS512 algorithm", + algorithm: "HS512", + secret: []byte("my-secret-key-64-bytes-long-for-hs512-algorithm-testing!!!!"), + expectErr: false, + }, + { + name: "invalid algorithm", + algorithm: "INVALID_ALG", + secret: []byte("my-secret-key"), + expectErr: true, + errContains: "unsupported key algorithm", + }, + { + name: "empty secret", + algorithm: "HS256", + secret: []byte{}, + expectErr: true, + errContains: "invalid symmetric key", + }, + { + name: "nil secret", + algorithm: "HS256", + secret: nil, + expectErr: true, + errContains: "invalid symmetric key", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + km, err := symmetric.NewTokenizer(tc.algorithm, tc.secret) + + if tc.expectErr { + assert.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + assert.Nil(t, km) + } else { + assert.NoError(t, err) + assert.NotNil(t, km) + } + }) + } +} + +func TestSign(t *testing.T) { + secret := []byte("my-super-secret-key-for-testing") + + km, err := symmetric.NewTokenizer("HS256", secret) + require.NoError(t, err) + + cases := []struct { + name string + key auth.Key + }{ + { + name: "sign valid key with all fields", + key: auth.Key{ + ID: "key-id", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + }, + }, + { + name: "sign key without subject", + key: auth.Key{ + ID: "key-id", + Type: auth.APIKey, + Issuer: "supermq.auth", + Role: auth.AdminRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(24 * time.Hour).UTC(), + Verified: false, + }, + }, + { + name: "sign key without ID", + key: auth.Key{ + Type: auth.AccessKey, + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + token, err := km.Issue(tc.key) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // Verify the token is valid JWT format (3 parts separated by dots) + parts := splitJWT(token) + assert.Equal(t, 3, len(parts), "JWT should have 3 parts") + }) + } +} + +func TestVerify(t *testing.T) { + secret := []byte("my-super-secret-key-for-testing") + + km, err := symmetric.NewTokenizer("HS256", secret) + require.NoError(t, err) + + validKey := auth.Key{ + ID: "key-id", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + } + + validToken, err := km.Issue(validKey) + require.NoError(t, err, "Signing valid token should succeed") + + expiredKey := validKey + expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC() + expiredToken, err := km.Issue(expiredKey) + require.NoError(t, err) + + wrongIssuerKey := validKey + wrongIssuerKey.Issuer = "wrong.issuer" + + builder := jwt.NewBuilder() + builder.Issuer(wrongIssuerKey.Issuer). + Subject(wrongIssuerKey.Subject). + IssuedAt(wrongIssuerKey.IssuedAt). + Expiration(wrongIssuerKey.ExpiresAt). + JwtID(wrongIssuerKey.ID). + Claim("type", wrongIssuerKey.Type). + Claim("role", wrongIssuerKey.Role). + Claim("verified", wrongIssuerKey.Verified) + + wrongIssuerJWT, err := builder.Build() + require.NoError(t, err) + + wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.HS256, secret)) + require.NoError(t, err) + wrongIssuerToken := string(wrongIssuerTokenBytes) + + wrongSecretKM, err := symmetric.NewTokenizer("HS256", []byte("different-secret-key-here")) + require.NoError(t, err) + wrongSecretToken, err := wrongSecretKM.Issue(validKey) + require.NoError(t, err) + + cases := []struct { + name string + token string + expectErr bool + errContains string + }{ + { + name: "verify valid token", + token: validToken, + expectErr: false, + }, + { + name: "verify expired token", + token: expiredToken, + expectErr: true, + errContains: "exp", + }, + { + name: "verify token with wrong issuer", + token: wrongIssuerToken, + expectErr: true, + errContains: "invalid token issuer", + }, + { + name: "verify token with wrong secret", + token: wrongSecretToken, + expectErr: true, + errContains: "", + }, + { + name: "verify malformed token", + token: "not.a.valid.jwt", + expectErr: true, + errContains: "", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + key, err := km.Parse(context.Background(), tc.token) + + if tc.expectErr { + assert.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + } else { + assert.NoError(t, err) + assert.Equal(t, validKey.Subject, key.Subject) + assert.Equal(t, validKey.Type, key.Type) + assert.Equal(t, validKey.Role, key.Role) + } + }) + } +} + +func TestPublicKeys(t *testing.T) { + secret := []byte("my-super-secret-key-for-testing") + + km, err := symmetric.NewTokenizer("HS256", secret) + require.NoError(t, err) + + keys, err := km.RetrieveJWKS() + assert.Error(t, err) + assert.Equal(t, auth.ErrPublicKeysNotSupported, err) + assert.Nil(t, keys) +} + +func TestSignAndVerifyRoundTrip(t *testing.T) { + algorithms := []string{"HS256", "HS384", "HS512"} + + for _, alg := range algorithms { + t.Run(alg, func(t *testing.T) { + secret := []byte("my-super-secret-key-for-testing-" + alg) + + km, err := symmetric.NewTokenizer(alg, secret) + require.NoError(t, err) + + originalKey := auth.Key{ + ID: "key-123", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-456", + Role: auth.UserRole, + IssuedAt: time.Now().UTC().Truncate(time.Second), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second), + Verified: true, + } + + token, err := km.Issue(originalKey) + require.NoError(t, err) + + verifiedKey, err := km.Parse(context.Background(), token) + require.NoError(t, err) + + assert.Equal(t, originalKey.ID, verifiedKey.ID) + assert.Equal(t, originalKey.Type, verifiedKey.Type) + assert.Equal(t, originalKey.Subject, verifiedKey.Subject) + assert.Equal(t, originalKey.Role, verifiedKey.Role) + assert.Equal(t, originalKey.Verified, verifiedKey.Verified) + assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second) + assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second) + }) + } +} + +func TestDifferentAlgorithms(t *testing.T) { + secret := []byte("my-super-secret-key-for-testing-algorithms") + + key := auth.Key{ + ID: "key-id", + Type: auth.AccessKey, + Issuer: "supermq.auth", + Subject: "user-id", + Role: auth.UserRole, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), + Verified: true, + } + + km256, err := symmetric.NewTokenizer("HS256", secret) + require.NoError(t, err) + token256, err := km256.Issue(key) + require.NoError(t, err) + + km384, err := symmetric.NewTokenizer("HS384", secret) + require.NoError(t, err) + token384, err := km384.Issue(key) + require.NoError(t, err) + + km512, err := symmetric.NewTokenizer("HS512", secret) + require.NoError(t, err) + token512, err := km512.Issue(key) + require.NoError(t, err) + + assert.NotEqual(t, token256, token384) + assert.NotEqual(t, token256, token512) + assert.NotEqual(t, token384, token512) + + _, err = km256.Parse(context.Background(), token256) + assert.NoError(t, err, "verification of km256 token should pass with km256 verifier") + + _, err = km384.Parse(context.Background(), token384) + assert.NoError(t, err, "verification of km384 token should pass with km384 verifier") + + _, err = km512.Parse(context.Background(), token512) + assert.NoError(t, err, "verification of km512 token should pass with km512 verifier") + + _, err = km384.Parse(context.Background(), token256) + assert.Error(t, err, "Cross verification should fail") + + _, err = km512.Parse(context.Background(), token256) + assert.Error(t, err) +} + +func splitJWT(token string) []string { + parts := []string{} + start := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + parts = append(parts, token[start:i]) + start = i + 1 + } + } + parts = append(parts, token[start:]) + return parts +} diff --git a/auth/tokenizer/util/jwt.go b/auth/tokenizer/util/jwt.go new file mode 100644 index 000000000..f18c6ab2a --- /dev/null +++ b/auth/tokenizer/util/jwt.go @@ -0,0 +1,110 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package util + +import ( + "encoding/json" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +var ( + // ErrInvalidIssuer represents an invalid token issuer value. + ErrInvalidIssuer = errors.New("invalid token issuer value") + + // ErrJSONHandle indicates an error in handling JSON. + ErrJSONHandle = errors.New("failed to perform operation JSON") + + errInvalidType = errors.New("invalid token type") + errInvalidRole = errors.New("invalid role") + errInvalidVerified = errors.New("invalid verified") +) + +const ( + IssuerName = "supermq.auth" + TokenType = "type" + RoleField = "role" + VerifiedField = "verified" +) + +// ToKey converts a JWT token to an auth.Key by extracting claims. +func ToKey(tkn jwt.Token) (auth.Key, error) { + data, err := json.Marshal(tkn.PrivateClaims()) + if err != nil { + return auth.Key{}, errors.Wrap(ErrJSONHandle, err) + } + var key auth.Key + if err := json.Unmarshal(data, &key); err != nil { + return auth.Key{}, errors.Wrap(ErrJSONHandle, err) + } + + tType, ok := tkn.Get(TokenType) + if !ok { + return auth.Key{}, errInvalidType + } + kType, ok := tType.(float64) + if !ok { + return auth.Key{}, errInvalidType + } + kt := auth.KeyType(kType) + if !kt.Validate() { + return auth.Key{}, errInvalidType + } + + tRole, ok := tkn.Get(RoleField) + if !ok { + return auth.Key{}, errInvalidRole + } + kRole, ok := tRole.(float64) + if !ok { + return auth.Key{}, errInvalidRole + } + + tVerified, ok := tkn.Get(VerifiedField) + if !ok { + return auth.Key{}, errInvalidVerified + } + kVerified, ok := tVerified.(bool) + if !ok { + return auth.Key{}, errInvalidVerified + } + + kr := auth.Role(kRole) + if !kr.Validate() { + return auth.Key{}, errInvalidRole + } + + key.ID = tkn.JwtID() + key.Type = auth.KeyType(kType) + key.Role = auth.Role(kRole) + key.Issuer = tkn.Issuer() + key.Subject = tkn.Subject() + key.IssuedAt = tkn.IssuedAt() + key.ExpiresAt = tkn.Expiration() + key.Verified = kVerified + + return key, nil +} + +func BuildToken(key auth.Key) (jwt.Token, error) { + builder := jwt.NewBuilder() + builder. + Issuer(IssuerName). + IssuedAt(key.IssuedAt). + Claim(TokenType, key.Type). + Expiration(key.ExpiresAt). + Claim(RoleField, key.Role). + Claim(VerifiedField, key.Verified) + + if key.Subject != "" { + builder.Subject(key.Subject) + } + if key.ID != "" { + builder.JwtID(key.ID) + } + + return builder.Build() +} diff --git a/channels/channels.go b/channels/channels.go index 716513e0b..146962797 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -136,8 +136,7 @@ type Service interface { // ChannelRepository specifies a channel persistence API. type Repository interface { // Save persists multiple channels. Channels are saved using a transaction. If one channel - // fails then none will be saved. Successful operation is indicated by non-nil - // error response. + // fails then none will be saved. Successful operation is indicated by non-nil error response. Save(ctx context.Context, chs ...Channel) ([]Channel, error) // Update performs an update to the existing channel. diff --git a/cmd/auth/main.go b/cmd/auth/main.go index 69a94c983..d2e36d71c 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -22,11 +22,10 @@ import ( httpapi "github.com/absmach/supermq/auth/api/http" "github.com/absmach/supermq/auth/cache" "github.com/absmach/supermq/auth/hasher" - "github.com/absmach/supermq/auth/jwt" - "github.com/absmach/supermq/auth/keymanager/asymmetric" - "github.com/absmach/supermq/auth/keymanager/symmetric" "github.com/absmach/supermq/auth/middleware" apostgres "github.com/absmach/supermq/auth/postgres" + "github.com/absmach/supermq/auth/tokenizer/asymmetric" + "github.com/absmach/supermq/auth/tokenizer/symmetric" redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" "github.com/absmach/supermq/pkg/jaeger" @@ -69,7 +68,8 @@ type config struct { AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"` RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"` KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"` - PrivateKeyPath string `env:"SMQ_AUTH_KEYS_PRIVATE_KEY_PATH" envDefault:"./ssl/keys/private.pem"` + ActiveKeyPath string `env:"SMQ_AUTH_KEYS_ACTIVE_KEY_PATH" envDefault:"./keys/active.key"` + RetiringKeyPath string `env:"SMQ_AUTH_KEYS_RETIRING_KEY_PATH" envDefault:""` InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"` SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` @@ -159,17 +159,23 @@ func main() { idProvider := uuid.New() - var keyManager auth.KeyManager + if err := validateKeyConfig(isSymmetric, cfg, logger); err != nil { + logger.Error(fmt.Sprintf("invalid key configuration: %s", err.Error())) + exitCode = 1 + return + } + + var tokenizer auth.Tokenizer switch { case isSymmetric: - keyManager, err = symmetric.NewKeyManager(cfg.KeyAlgorithm, []byte(cfg.SecretKey)) + tokenizer, err = symmetric.NewTokenizer(cfg.KeyAlgorithm, []byte(cfg.SecretKey)) if err != nil { logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error())) exitCode = 1 return } default: - keyManager, err = asymmetric.NewKeyManager(cfg.PrivateKeyPath, idProvider) + tokenizer, err = asymmetric.NewTokenizer(cfg.ActiveKeyPath, cfg.RetiringKeyPath, idProvider, logger) if err != nil { logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error())) exitCode = 1 @@ -177,7 +183,7 @@ func main() { } } - svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, keyManager, idProvider) + svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, tokenizer, idProvider) if err != nil { logger.Error(fmt.Sprintf("failed to create service : %s\n", err.Error())) exitCode = 1 @@ -258,7 +264,34 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, keyManager auth.KeyManager, idProvider supermq.IDProvider) (auth.Service, error) { +func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error { + if isSymmetric { + if cfg.SecretKey == "secret" { + return fmt.Errorf("default secret key is insecure - please set SMQ_AUTH_SECRET_KEY environment variable") + } + return nil + } + + // Validate active key path + _, err := os.Stat(cfg.ActiveKeyPath) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("active key file not found: %s - please set SMQ_AUTH_KEYS_ACTIVE_KEY_PATH", cfg.ActiveKeyPath) + } + return fmt.Errorf("failed to access active key file: %w", err) + } + + // Retiring key is optional - only validate if path is provided + if cfg.RetiringKeyPath != "" { + if _, err := os.Stat(cfg.RetiringKeyPath); err != nil { + l.Warn("retiring key path provided but file not accessible", slog.Any("error", err)) + } + } + + return nil +} + +func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) { cache := cache.NewPatsCache(cacheClient, keyDuration) database := pgclient.NewDatabase(db, dbConfig, tracer) @@ -269,8 +302,6 @@ func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient. pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger) pService := spicedb.NewPolicyService(spicedbClient, logger) - tokenizer := jwt.New(keyManager) - svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = middleware.NewLogging(svc, logger) counter, latency := prometheus.MakeMetrics("auth", "api") diff --git a/docker/.env b/docker/.env index c879ba545..9d682d7f0 100644 --- a/docker/.env +++ b/docker/.env @@ -101,7 +101,8 @@ SMQ_AUTH_DB_SSL_ROOT_CERT= SMQ_AUTH_ACCESS_TOKEN_DURATION="1h" SMQ_AUTH_REFRESH_TOKEN_DURATION="24h" SMQ_AUTH_KEYS_ALGORITHM="EdDSA" -SMQ_AUTH_KEYS_PRIVATE_KEY_PATH="./ssl/keys/private.key" +SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" +SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" SMQ_AUTH_INVITATION_DURATION="168h" SMQ_AUTH_ADAPTER_INSTANCE_ID= SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1df5bf444..bd0a45d49 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -121,7 +121,8 @@ services: SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION} SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION} SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_AUTH_KEYS_PRIVATE_KEY_PATH: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:+/keys/private.key} + SMQ_AUTH_KEYS_ACTIVE_KEY_PATH: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH:+/keys/active.key} + SMQ_AUTH_KEYS_RETIRING_KEY_PATH: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+/keys/retiring.key} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default @@ -152,11 +153,16 @@ services: volumes: - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} - supermq-pat-db-volume:/supermq-data - - supermq-auth-keys-volume:/keys - # Auth private key file + # Auth active private key file - type: bind - source: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:-ssl/certs/dummy/private_key} - target: /keys/private.key + source: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH} + target: /keys/active.key + read_only: true + # Auth retiring private key file (optional, for key rotation) + - type: bind + source: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:-ssl/certs/dummy/retiring_key} + target: /keys/retiring${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+.key} + read_only: true bind: create_host_path: true # Auth gRPC mTLS server certificates diff --git a/docker/ssl/keys/private.key b/docker/keys/active.key similarity index 100% rename from docker/ssl/keys/private.key rename to docker/keys/active.key diff --git a/docker/keys/retiring.key b/docker/keys/retiring.key new file mode 100644 index 000000000..89426b698 --- /dev/null +++ b/docker/keys/retiring.key @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIE9Qu5lN6KOfdO14XJUClM1UPrqT55BczLMcRuSG7Ziy +-----END PRIVATE KEY----- diff --git a/pkg/authn/jwks/authn.go b/pkg/authn/jwks/authn.go index 037ce88b1..96605155b 100644 --- a/pkg/authn/jwks/authn.go +++ b/pkg/authn/jwks/authn.go @@ -5,6 +5,7 @@ package jwks import ( "context" + "fmt" "io" "net/http" "strings" @@ -14,7 +15,7 @@ import ( grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1" smqauth "github.com/absmach/supermq/auth" "github.com/absmach/supermq/auth/api/grpc/auth" - smqjwt "github.com/absmach/supermq/auth/jwt" + smqjwt "github.com/absmach/supermq/auth/tokenizer/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" @@ -26,8 +27,13 @@ import ( ) const ( - issuerName = "supermq.auth" - cacheDuration = 5 * time.Minute + issuerName = "supermq.auth" + acceptHeader = "application/json" + + fetchKeyDeadline = 10 * time.Second + cacheDuration = 5 * time.Minute + + errorBodyBytes = 1024 ) var ( @@ -39,12 +45,6 @@ var ( errInvalidIssuer = errors.New("invalid token issuer value") // ErrValidateJWTToken indicates a failure to validate JWT token. errValidateJWTToken = errors.New("failed to validate jwt token") - - jwksCache = struct { - sync.RWMutex - jwks jwk.Set - cachedAt time.Time - }{} ) var _ authn.Authentication = (*authentication)(nil) @@ -52,6 +52,14 @@ var _ authn.Authentication = (*authentication)(nil) type authentication struct { jwksURL string authSvcClient grpcAuthV1.AuthServiceClient + httpClient *http.Client + cache *jwksCache +} + +type jwksCache struct { + mu sync.RWMutex + jwks jwk.Set + cachedAt time.Time } func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) { @@ -69,9 +77,13 @@ func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Confi } authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout) + httpClient := &http.Client{} + return authentication{ jwksURL: jwksURL, authSvcClient: authSvcClient, + httpClient: httpClient, + cache: &jwksCache{}, }, client, nil } @@ -84,14 +96,26 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil } - jwks, err := a.fetchJWKS(ctx) + jwks, err := a.fetchJWKS(ctx, false) if err != nil { return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err) } + tkn, err := validateToken(token, jwks) if err != nil { - return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err) + // If signature verification failed, try force with refresh JWKS (key rotation scenario) + if isSignatureError(err) { + jwks, fetchErr := a.fetchJWKS(ctx, true) + if fetchErr == nil { + tkn, err = validateToken(token, jwks) + } + } + + if err != nil { + return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err) + } } + key, err := smqjwt.ToKey(tkn) if err != nil { return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -105,45 +129,63 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S }, nil } -func (a authentication) fetchJWKS(ctx context.Context) (jwk.Set, error) { - jwksCache.RLock() - if time.Since(jwksCache.cachedAt) < cacheDuration && jwksCache.jwks.Len() > 0 { - cached := jwksCache.jwks - jwksCache.RUnlock() - return cached, nil - } - jwksCache.RUnlock() +func isSignatureError(err error) bool { + return !errors.Contains(err, errJWTExpiryKey) && + !errors.Contains(err, errInvalidIssuer) && + !errors.Contains(err, smqauth.ErrExpiry) +} - req, err := http.NewRequestWithContext(ctx, "GET", a.jwksURL, nil) - if err != nil { - return nil, err +func (a authentication) fetchJWKS(ctx context.Context, forceRefresh bool) (jwk.Set, error) { + if !forceRefresh { + a.cache.mu.RLock() + if time.Since(a.cache.cachedAt) < cacheDuration && a.cache.jwks.Len() > 0 { + cached := a.cache.jwks + a.cache.mu.RUnlock() + return cached, nil + } + a.cache.mu.RUnlock() } - req.Header.Set("Accept", "application/json") - httpClient := &http.Client{ - Timeout: 10 * time.Second, + fetchCtx := ctx + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + fetchCtx, cancel = context.WithTimeout(ctx, fetchKeyDeadline) + defer cancel() } - resp, err := httpClient.Do(req) + + // Fetch fresh JWKS from auth service + req, err := http.NewRequestWithContext(fetchCtx, http.MethodGet, a.jwksURL, nil) if err != nil { - return nil, err + return nil, errors.Wrap(errFetchJWKS, err) + } + req.Header.Set("Accept", acceptHeader) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, errors.Wrap(errFetchJWKS, err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { - return nil, errFetchJWKS + // Read error body for better diagnostics + body, _ := io.ReadAll(io.LimitReader(resp.Body, errorBodyBytes)) + return nil, errors.Wrap(errFetchJWKS, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))) } data, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, errors.Wrap(errFetchJWKS, err) } + set, err := jwk.Parse(data) if err != nil { - return nil, err + return nil, errors.Wrap(errFetchJWKS, err) } - jwksCache.Lock() - jwksCache.jwks = set - jwksCache.cachedAt = time.Now() - jwksCache.Unlock() + + a.cache.mu.Lock() + a.cache.jwks = set + a.cache.cachedAt = time.Now() + a.cache.mu.Unlock() return set, nil } @@ -160,6 +202,8 @@ func validateToken(token string, jwks jwk.Set) (jwt.Token, error) { } return nil, err } + + // Validate issuer validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError { if t.Issuer() != issuerName { return jwt.NewValidationError(errInvalidIssuer) diff --git a/pkg/authn/jwks/authn_test.go b/pkg/authn/jwks/authn_test.go new file mode 100644 index 000000000..dc6c1da32 --- /dev/null +++ b/pkg/authn/jwks/authn_test.go @@ -0,0 +1,336 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package jwks_test + +import ( + "crypto/ed25519" + "crypto/rand" + "testing" + "time" + + "github.com/absmach/supermq/auth" + smqjwt "github.com/absmach/supermq/auth/tokenizer/util" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + validIssuer = "supermq.auth" + invalidIssuer = "invalid.issuer" + userID = "user123" +) + +func TestValidateToken(t *testing.T) { + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + kid := "test-key-id" + + privateJwk, err := jwk.FromRaw(privateKey) + require.NoError(t, err, "Parsing JWKS private key should work") + require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid)) + + publicJwk, err := jwk.FromRaw(publicKey) + require.NoError(t, err, "Parsing JWKS public key should work") + require.NoError(t, publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk.Set(jwk.KeyIDKey, kid)) + + jwksSet := jwk.NewSet() + require.NoError(t, jwksSet.AddKey(publicJwk), "Creation of JWKS should succeed") + + cases := []struct { + name string + issuer string + expiry time.Time + expectErr bool + }{ + { + name: "valid token", + issuer: validIssuer, + expiry: time.Now().Add(1 * time.Hour), + expectErr: false, + }, + { + name: "expired token", + issuer: validIssuer, + expiry: time.Now().Add(-1 * time.Hour), + expectErr: true, + }, + { + name: "invalid issuer", + issuer: invalidIssuer, + expiry: time.Now().Add(1 * time.Hour), + expectErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + builder := jwt.NewBuilder() + builder.Issuer(tc.issuer) + builder.IssuedAt(time.Now()) + builder.Expiration(tc.expiry) + builder.Subject(userID) + builder.Claim(smqjwt.TokenType, auth.AccessKey) + builder.Claim(smqjwt.RoleField, auth.UserRole) + builder.Claim(smqjwt.VerifiedField, true) + + token, err := builder.Build() + require.NoError(t, err) + + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk)) + require.NoError(t, err) + + parsedToken, err := jwt.Parse( + signedToken, + jwt.WithValidate(true), + jwt.WithKeySet(jwksSet), + ) + + if tc.expectErr { + // For expired token, parsing will fail + // For invalid issuer, parsing succeeds but we need to check issuer + if tc.issuer != validIssuer && err == nil { + if parsedToken.Issuer() != validIssuer { + assert.NotEqual(t, validIssuer, parsedToken.Issuer()) + } else { + assert.Fail(t, "Expected invalid issuer error") + } + } else { + assert.Error(t, err) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, parsedToken) + assert.Equal(t, tc.issuer, parsedToken.Issuer()) + } + }) + } +} + +func TestMultiKeyJWKS(t *testing.T) { + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pub2, priv2, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + activeKID := "key-active" + retiringKID := "key-retiring" + + privateJwk1, err := jwk.FromRaw(priv1) + require.NoError(t, err, "Parsing JWKS private key 1 should work") + require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, activeKID)) + + publicJwk1, err := jwk.FromRaw(pub1) + require.NoError(t, err, "Parsing JWKS public key 1 should work") + require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, activeKID)) + + privateJwk2, err := jwk.FromRaw(priv2) + require.NoError(t, err, "Parsing JWKS private key 2 should work") + require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, retiringKID)) + + publicJwk2, err := jwk.FromRaw(pub2) + require.NoError(t, err, "Parsing JWKS public key 2 should work") + require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, retiringKID)) + + // Create JWKS set with both keys (simulating rotation period) + jwksSet := jwk.NewSet() + require.NoError(t, jwksSet.AddKey(publicJwk1)) + require.NoError(t, jwksSet.AddKey(publicJwk2)) + + assert.Equal(t, 2, jwksSet.Len(), "JWKS should contain both keys") + + cases := []struct { + name string + privateKey jwk.Key + kid string + issuer string + expiry time.Time + expectErr bool + }{ + { + name: "token signed with active key", + privateKey: privateJwk1, + kid: activeKID, + issuer: validIssuer, + expiry: time.Now().Add(1 * time.Hour), + expectErr: false, + }, + { + name: "token signed with retiring key", + privateKey: privateJwk2, + kid: retiringKID, + issuer: validIssuer, + expiry: time.Now().Add(1 * time.Hour), + expectErr: false, + }, + { + name: "expired token with active key", + privateKey: privateJwk1, + kid: activeKID, + issuer: validIssuer, + expiry: time.Now().Add(-1 * time.Hour), + expectErr: true, + }, + { + name: "expired token with retiring key", + privateKey: privateJwk2, + kid: retiringKID, + issuer: validIssuer, + expiry: time.Now().Add(-1 * time.Hour), + expectErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + builder := jwt.NewBuilder() + builder.Issuer(tc.issuer) + builder.IssuedAt(time.Now()) + builder.Expiration(tc.expiry) + builder.Subject(userID) + builder.Claim(smqjwt.TokenType, auth.AccessKey) + builder.Claim(smqjwt.RoleField, auth.UserRole) + builder.Claim(smqjwt.VerifiedField, true) + + token, err := builder.Build() + require.NoError(t, err) + + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, tc.privateKey)) + require.NoError(t, err) + + // Parse the token using the multi-key JWKS set + parsedToken, err := jwt.Parse( + signedToken, + jwt.WithValidate(true), + jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)), + ) + + if tc.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, parsedToken) + assert.Equal(t, tc.issuer, parsedToken.Issuer()) + assert.Equal(t, userID, parsedToken.Subject()) + } + }) + } +} + +func TestKeyIDMatching(t *testing.T) { + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err, "Generating key par 1 should succeed") + + pub2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err, "Generating key par 2 should succeed") + + kid1 := "key-1" + kid2 := "key-2" + + privateJwk1, err := jwk.FromRaw(priv1) + require.NoError(t, err, "Parsing JWKS private key 1 should work") + require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, kid1)) + + publicJwk1, err := jwk.FromRaw(pub1) + require.NoError(t, err, "Parsing JWKS public key 1 should work") + require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1)) + + publicJwk2, err := jwk.FromRaw(pub2) + require.NoError(t, err, "Parsing JWKS public key 2 should work") + require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, kid2)) + + jwksSet := jwk.NewSet() + require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to JWKS set should succeed") + require.NoError(t, jwksSet.AddKey(publicJwk2), "Adding public key 2 to JWKS set should succeed") + + // Create token signed with key-1 + builder := jwt.NewBuilder() + builder.Issuer(validIssuer) + builder.IssuedAt(time.Now()) + builder.Expiration(time.Now().Add(1 * time.Hour)) + builder.Subject(userID) + + token, err := builder.Build() + require.NoError(t, err) + + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk1)) + require.NoError(t, err) + + parsedToken, err := jwt.Parse( + signedToken, + jwt.WithValidate(true), + jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)), + ) + + require.NoError(t, err) + assert.NotNil(t, parsedToken) + assert.Equal(t, validIssuer, parsedToken.Issuer()) + + headers, err := jws.Parse(signedToken) + require.NoError(t, err) + sigs := headers.Signatures() + require.Len(t, sigs, 1, "JWT should have exactly one signature") + + kidValue := sigs[0].ProtectedHeaders().KeyID() + assert.Equal(t, kid1, kidValue, "JWT kid header should match signing key") +} + +func TestUnknownKeyID(t *testing.T) { + pub1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err, "Generating public key 1 should succeed") + + _, priv2, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err, "Generating private key 2 should succeed") + + kid1 := "known-key" + kid2 := "unknown-key" + + publicJwk1, err := jwk.FromRaw(pub1) + require.NoError(t, err, "Parsing JWKS public key 1 should work") + require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1)) + + privateJwk2, err := jwk.FromRaw(priv2) + require.NoError(t, err, "Parsing JWKS private key 2 should work") + require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA)) + require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, kid2)) + + jwksSet := jwk.NewSet() + require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to set should succeed") + + // Create token signed with unknown key-2 + builder := jwt.NewBuilder() + builder.Issuer(validIssuer) + builder.IssuedAt(time.Now()) + builder.Expiration(time.Now().Add(1 * time.Hour)) + builder.Subject(userID) + + token, err := builder.Build() + require.NoError(t, err) + + signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk2)) + require.NoError(t, err) + + _, err = jwt.Parse( + signedToken, + jwt.WithValidate(true), + jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)), + ) + + assert.Error(t, err, "Should fail when token's kid is not in JWKS") +} diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 3e4e5415e..c62d33158 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -58,7 +58,6 @@ packages: Cache: Hasher: KeyRepository: - KeyManager: Tokenizer: PATS: PATSRepository: