NOISSUE - Improve JWKS (#3301)

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2025-12-26 18:15:12 +01:00
committed by GitHub
parent 6b6bab79c6
commit 52510d8c62
32 changed files with 2111 additions and 1063 deletions
+12 -18
View File
@@ -4,8 +4,6 @@
package keys_test
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
@@ -21,11 +19,8 @@ import (
"github.com/absmach/supermq/auth/mocks"
smqlog "github.com/absmach/supermq/logger"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
@@ -323,17 +318,17 @@ func TestRetrieveJWKS(t *testing.T) {
cases := []struct {
desc string
svcRes []auth.JWK
svcRes []auth.PublicKeyInfo
status int
}{
{
desc: "retrieve JWKS with keys",
svcRes: []auth.JWK{newJWK(t), newJWK(t)},
svcRes: []auth.PublicKeyInfo{newPublicKeyInfo(), newPublicKeyInfo()},
status: http.StatusOK,
},
{
desc: "retrieve empty JWKS",
svcRes: []auth.JWK{},
svcRes: []auth.PublicKeyInfo{},
status: http.StatusOK,
},
}
@@ -352,14 +347,13 @@ func TestRetrieveJWKS(t *testing.T) {
}
}
func newJWK(t *testing.T) auth.JWK {
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
return auth.NewJWK(jwkKey)
func newPublicKeyInfo() auth.PublicKeyInfo {
return auth.PublicKeyInfo{
KeyID: "test-key-id",
KeyType: "OKP",
Algorithm: "EdDSA",
Use: "sig",
Curve: "Ed25519",
X: "base64url-encoded-public-key",
}
}
+6 -10
View File
@@ -11,7 +11,6 @@ import (
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwk"
)
var (
@@ -76,20 +75,17 @@ func (res revokeKeyRes) Empty() bool {
}
type retrieveJWKSRes struct {
Keys []auth.JWK `json:"-"`
CacheMaxAge int `json:"-"`
CacheStaleWhileRevalidate int `json:"-"`
Keys []auth.PublicKeyInfo `json:"-"`
CacheMaxAge int `json:"-"`
CacheStaleWhileRevalidate int `json:"-"`
}
func (res retrieveJWKSRes) MarshalJSON() ([]byte, error) {
set := jwk.NewSet()
for _, k := range res.Keys {
if err := set.AddKey(k.Key()); err != nil {
return nil, err
}
type jwksResponse struct {
Keys []auth.PublicKeyInfo `json:"keys"`
}
return json.Marshal(set)
return json.Marshal(jwksResponse{Keys: res.Keys})
}
func (res retrieveJWKSRes) Code() int {
-341
View File
@@ -1,341 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package jwt_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
"time"
"github.com/absmach/supermq/auth"
authjwt "github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/auth/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
tokenType = "type"
roleField = "role"
VerifiedField = "verified"
issuerName = "supermq.auth"
)
var (
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
keyManager = new(mocks.KeyManager)
)
func TestIssue(t *testing.T) {
tokenizer := authjwt.New(keyManager)
validKey := key()
signedToken, _, err := signToken(issuerName, validKey, false)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
cases := []struct {
desc string
key auth.Key
managerReq jwt.Token
managerResp []byte
managerErr error
err error
}{
{
desc: "issue new token",
key: validKey,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token with OAuth token",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a domain",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a subject",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: "",
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without type",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.KeyType(auth.InvitationKey + 1),
Subject: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a domain and subject",
key: auth.Key{
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: "",
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token with failed to sign jwt",
key: validKey,
managerErr: svcerr.ErrAuthentication,
err: authjwt.ErrSignJWT,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
tc.managerReq = newToken(issuerName, tc.key)
kmCall := keyManager.On("SignJWT", tc.managerReq).Return(tc.managerResp, tc.managerErr)
tkn, err := tokenizer.Issue(tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
}
kmCall.Unset()
})
}
}
func TestParse(t *testing.T) {
tokenizer := authjwt.New(keyManager)
validKey := key()
signedTkn, parsedTkn, err := signToken(issuerName, validKey, true)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
apiKey := key()
apiKey.Type = auth.APIKey
apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
apiToken, _, err := signToken(issuerName, apiKey, false)
require.Nil(t, err, fmt.Sprintf("issuing api key expected to succeed: %s", err))
expKey := key()
expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
expToken, _, err := signToken(issuerName, expKey, false)
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
emptySubjectKey := key()
emptySubjectKey.Subject = ""
signedEmptySubjectTkn, parsedEmptySubjectTkn, err := signToken(issuerName, emptySubjectKey, true)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptyTypeKey := key()
emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1)
emptyTypeToken, _, err := signToken(issuerName, emptyTypeKey, false)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptyKey := key()
emptyKey.Subject = ""
signedInValidTkn, parsedInvalidTkn, err := signToken("invalid.issuer", key(), true)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
cases := []struct {
desc string
key auth.Key
token string
managerRes jwt.Token
managerErr error
err error
}{
{
desc: "parse valid key",
key: validKey,
token: signedTkn,
managerRes: parsedTkn,
err: nil,
},
{
desc: "parse invalid key",
key: auth.Key{},
token: "invalid",
managerErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "parse expired key",
key: auth.Key{},
token: expToken,
managerErr: errJWTExpiryKey,
err: auth.ErrExpiry,
},
{
desc: "parse expired API key",
key: apiKey,
token: apiToken,
managerErr: errJWTExpiryKey,
err: auth.ErrExpiry,
},
{
desc: "parse token with invalid issuer",
key: auth.Key{},
token: signedInValidTkn,
managerRes: parsedInvalidTkn,
err: svcerr.ErrAuthentication,
},
{
desc: "parse token with empty subject",
key: emptySubjectKey,
token: signedEmptySubjectTkn,
managerRes: parsedEmptySubjectTkn,
err: nil,
},
{
desc: "parse token with empty type",
key: emptyTypeKey,
token: emptyTypeToken,
managerRes: newToken(issuerName, emptyKey),
err: svcerr.ErrAuthentication,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
kmCall := keyManager.On("ParseJWT", tc.token).Return(tc.managerRes, tc.managerErr)
key, err := tokenizer.Parse(context.Background(), tc.token)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
}
kmCall.Unset()
})
}
}
func TestRetrieveJWKS(t *testing.T) {
tokenizer := authjwt.New(keyManager)
cases := []struct {
desc string
keys []auth.JWK
retrieveErr error
err error
}{
{
desc: "retrieve jwks with keys",
keys: []auth.JWK{newJWK(t), newJWK(t)},
},
{
desc: "retrieve empty jwks",
keys: []auth.JWK{},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
kmCall := keyManager.On("PublicJWKS", mock.Anything).Return(tc.keys, tc.retrieveErr)
jwks := tokenizer.RetrieveJWKS()
assert.Equal(t, tc.keys, jwks, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.keys, jwks))
kmCall.Unset()
})
}
}
func key() auth.Key {
exp := time.Now().UTC().Add(10 * time.Minute).Round(time.Second)
return auth.Key{
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Role: auth.UserRole,
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: exp,
}
}
func newToken(issuerName string, key auth.Key) jwt.Token {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
IssuedAt(key.IssuedAt).
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(roleField, key.Role)
builder.Claim(VerifiedField, key.Verified)
if key.Subject != "" {
builder.Subject(key.Subject)
}
if key.ID != "" {
builder.JwtID(key.ID)
}
tkn, _ := builder.Build()
return tkn
}
func signToken(issuerName string, key auth.Key, parseToken bool) (string, jwt.Token, error) {
tkn := newToken(issuerName, key)
pKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return "", nil, err
}
pubKey := &pKey.PublicKey
sTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.RS256, pKey))
if err != nil {
return "", nil, err
}
if !parseToken {
return string(sTkn), nil, nil
}
pTkn, err := jwt.Parse(
sTkn,
jwt.WithValidate(true),
jwt.WithKey(jwa.RS256, pubKey),
)
if err != nil {
return "", nil, err
}
return string(sTkn), pTkn, nil
}
func newJWK(t *testing.T) auth.JWK {
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
return auth.NewJWK(jwkKey)
}
-182
View File
@@ -1,182 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"encoding/json"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
// errInvalidIssuer is returned when the issuer is not supermq.auth.
errInvalidIssuer = errors.New("invalid token issuer value")
// errInvalidType is returned when there is no type field.
errInvalidType = errors.New("invalid token type")
// errInvalidRole is returned when the role is invalid.
errInvalidRole = errors.New("invalid role")
// errInvalidVerified is returned when the verified is invalid.
errInvalidVerified = errors.New("invalid verified")
// errJWTExpiryKey is used to check if the token is expired.
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
// ErrSignJWT indicates an error in signing jwt token.
ErrSignJWT = errors.New("failed to sign jwt token")
// ErrValidateJWTToken indicates a failure to validate JWT token.
ErrValidateJWTToken = errors.New("failed to validate jwt token")
// ErrJSONHandle indicates an error in handling JSON.
ErrJSONHandle = errors.New("failed to perform operation JSON")
)
const (
issuerName = "supermq.auth"
tokenType = "type"
RoleField = "role"
VerifiedField = "verified"
patPrefix = "pat"
)
type tokenizer struct {
keyManager auth.KeyManager
}
var _ auth.Tokenizer = (*tokenizer)(nil)
// New instantiates an implementation of Tokenizer service.
func New(keyManager auth.KeyManager) auth.Tokenizer {
return &tokenizer{
keyManager: keyManager,
}
}
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
IssuedAt(key.IssuedAt).
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(RoleField, key.Role)
builder.Claim(VerifiedField, key.Verified)
if key.Subject != "" {
builder.Subject(key.Subject)
}
if key.ID != "" {
builder.JwtID(key.ID)
}
tkn, err := builder.Build()
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
signedTkn, err := tok.keyManager.SignJWT(tkn)
if err != nil {
return "", errors.Wrap(ErrSignJWT, err)
}
return string(signedTkn), nil
}
func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
if len(token) >= 3 && token[:3] == patPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
tkn, err := tok.validateToken(token)
if err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key, err := ToKey(tkn)
if err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return key, nil
}
func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
tkn, err := tok.keyManager.ParseJWT(token)
if err != nil {
if errors.Contains(err, errJWTExpiryKey) {
return nil, auth.ErrExpiry
}
return nil, err
}
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
if t.Issuer() != issuerName {
return jwt.NewValidationError(errInvalidIssuer)
}
return nil
})
if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil {
return nil, errors.Wrap(ErrValidateJWTToken, err)
}
return tkn, nil
}
func (tok *tokenizer) RetrieveJWKS() []auth.JWK {
return tok.keyManager.PublicJWKS()
}
func ToKey(tkn jwt.Token) (auth.Key, error) {
data, err := json.Marshal(tkn.PrivateClaims())
if err != nil {
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
var key auth.Key
if err := json.Unmarshal(data, &key); err != nil {
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
tType, ok := tkn.Get(tokenType)
if !ok {
return auth.Key{}, errInvalidType
}
kType, ok := tType.(float64)
if !ok {
return auth.Key{}, errInvalidType
}
kt := auth.KeyType(kType)
if !kt.Validate() {
return auth.Key{}, errInvalidType
}
tRole, ok := tkn.Get(RoleField)
if !ok {
return auth.Key{}, errInvalidRole
}
kRole, ok := tRole.(float64)
if !ok {
return auth.Key{}, errInvalidRole
}
tVerified, ok := tkn.Get(VerifiedField)
if !ok {
return auth.Key{}, errInvalidVerified
}
kVerified, ok := tVerified.(bool)
if !ok {
return auth.Key{}, errInvalidVerified
}
kr := auth.Role(kRole)
if !kr.Validate() {
return auth.Key{}, errInvalidRole
}
key.ID = tkn.JwtID()
key.Type = auth.KeyType(kType)
key.Role = auth.Role(kRole)
key.Issuer = tkn.Issuer()
key.Subject = tkn.Subject()
key.IssuedAt = tkn.IssuedAt()
key.ExpiresAt = tkn.Expiration()
key.Verified = kVerified
return key, nil
}
+35 -25
View File
@@ -4,47 +4,57 @@
package auth
import (
"context"
"errors"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm")
)
// JWK represents a JSON Web Key.
type JWK struct {
key jwk.Key
// PublicKeyInfo represents a public key for external distribution via JWKS.
// This follows RFC 7517 (JSON Web Key) specification.
type PublicKeyInfo struct {
KeyID string `json:"kid"`
KeyType string `json:"kty"`
Algorithm string `json:"alg"`
Use string `json:"use,omitempty"`
// EdDSA (Ed25519) fields
Curve string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
// Future: RSA fields (n, e), ECDSA fields (x, y, crv), etc.
}
// NewJWK creates a new JWK from a jwk.Key.
func NewJWK(key jwk.Key) JWK {
return JWK{key: key}
}
// Key returns the underlying jwk.Key.
func (j JWK) Key() jwk.Key {
return j.key
}
// KeyManager represents a manager for JWT keys.
type KeyManager interface {
SignJWT(token jwt.Token) ([]byte, error)
ParseJWT(token string) (jwt.Token, error)
PublicJWKS() []JWK
// Tokenizer handles token creation and verification for authentication.
// Implementations manage underlying cryptographic operations and key distribution.
type Tokenizer interface {
// Issue creates a signed token string from the given key claims.
Issue(key Key) (token string, err error)
// Parse verifies and parses a token string (JWT or PAT), returning the extracted claims.
// For PAT tokens (prefix "pat"), returns a Key with Type set to PersonalAccessToken.
// For JWT tokens, performs cryptographic verification and returns the parsed claims.
Parse(ctx context.Context, token string) (key Key, err error)
// RetrieveJWKS returns public keys for distribution via JWKS endpoint.
// Returns ErrPublicKeysNotSupported for symmetric tokenizers (HMAC).
RetrieveJWKS() ([]PublicKeyInfo, error)
}
// IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based).
// Returns true for HMAC algorithms (HS256, HS384, HS512).
// Returns false for asymmetric algorithms (EdDSA).
// Returns error for unsupported algorithms.
func IsSymmetricAlgorithm(alg string) (bool, error) {
switch alg {
case "HS256", "HS384", "HS512":
return true, nil
case "EdDSA":
return false, nil
case "HS256", "HS384", "HS512":
return true, nil
default:
return false, ErrUnsupportedKeyAlgorithm
}
-132
View File
@@ -1,132 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package asymmetric
import (
"crypto/ed25519"
"crypto/x509"
"encoding/pem"
"errors"
"os"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
errLoadingPrivateKey = errors.New("failed to load private key")
errInvalidKeySize = errors.New("invalid ED25519 key size")
errParsingPrivateKey = errors.New("failed to parse private key")
errInvalidKeyType = errors.New("private key is not ED25519")
errGeneratingKID = errors.New("failed to generate key ID")
)
type manager struct {
privateKey jwk.Key
publicKey jwk.Key
kid string
}
var _ auth.KeyManager = (*manager)(nil)
// NewKeyManager creates a new asymmetric key manager that loads the private key from a file.
func NewKeyManager(privateKeyPath string, idProvider supermq.IDProvider) (auth.KeyManager, error) {
kid, err := idProvider.ID()
if err != nil {
return nil, errors.Join(errGeneratingKID, err)
}
privateJwk, publicJwk, err := loadKeyPair(privateKeyPath, kid)
if err != nil {
return nil, err
}
return &manager{
privateKey: privateJwk,
publicKey: publicJwk,
kid: kid,
}, nil
}
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, jwt.WithKey(jwa.EdDSA, km.privateKey))
}
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
set := jwk.NewSet()
if err := set.AddKey(km.publicKey); err != nil {
return nil, err
}
tkn, err := jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
)
if err != nil {
return nil, err
}
return tkn, nil
}
func (km *manager) PublicJWKS() []auth.JWK {
return []auth.JWK{auth.NewJWK(km.publicKey)}
}
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return nil, nil, errors.Join(errLoadingPrivateKey, err)
}
var privateKey ed25519.PrivateKey
block, _ := pem.Decode(privateKeyBytes)
switch {
case block != nil:
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Join(errParsingPrivateKey, err)
}
var ok bool
privateKey, ok = parsedKey.(ed25519.PrivateKey)
if !ok {
return nil, nil, errInvalidKeyType
}
default:
if len(privateKeyBytes) != ed25519.PrivateKeySize {
return nil, nil, errInvalidKeySize
}
privateKey = ed25519.PrivateKey(privateKeyBytes)
}
publicKey := privateKey.Public().(ed25519.PublicKey)
privateJwk, err := jwk.FromRaw(privateKey)
if err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
publicJwk, err := jwk.FromRaw(publicKey)
if err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
return privateJwk, publicJwk, nil
}
-47
View File
@@ -1,47 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package symmetric
import (
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)
type manager struct {
algorithm jwa.KeyAlgorithm
secret []byte
}
var _ auth.KeyManager = (*manager)(nil)
func NewKeyManager(algorithm string, secret []byte) (auth.KeyManager, error) {
alg := jwa.KeyAlgorithmFrom(algorithm)
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
return nil, auth.ErrUnsupportedKeyAlgorithm
}
if len(secret) == 0 {
return nil, auth.ErrInvalidSymmetricKey
}
return &manager{
secret: secret,
algorithm: alg,
}, nil
}
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, jwt.WithKey(km.algorithm, km.secret))
}
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
return jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKey(km.algorithm, km.secret),
)
}
func (km *manager) PublicJWKS() []auth.JWK {
return nil
}
+1 -1
View File
@@ -100,7 +100,7 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id aut
return lm.svc.Identify(ctx, token)
}
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.JWK) {
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
+1 -1
View File
@@ -67,7 +67,7 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.K
return ms.svc.Identify(ctx, token)
}
func (ms *metricsMiddleware) RetrieveJWKS() []auth.JWK {
func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
defer func(begin time.Time) {
ms.counter.With("method", "retrieve_jwks").Add(1)
ms.latency.With("method", "retrieve_jwks").Observe(time.Since(begin).Seconds())
+1 -1
View File
@@ -61,7 +61,7 @@ func (tm *tracingMiddleware) Identify(ctx context.Context, token string) (auth.K
return tm.svc.Identify(ctx, token)
}
func (tm *tracingMiddleware) RetrieveJWKS() []auth.JWK {
func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
return tm.svc.RetrieveJWKS()
}
-212
View File
@@ -1,212 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwt"
mock "github.com/stretchr/testify/mock"
)
// NewKeyManager creates a new instance of KeyManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewKeyManager(t interface {
mock.TestingT
Cleanup(func())
}) *KeyManager {
mock := &KeyManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// KeyManager is an autogenerated mock type for the KeyManager type
type KeyManager struct {
mock.Mock
}
type KeyManager_Expecter struct {
mock *mock.Mock
}
func (_m *KeyManager) EXPECT() *KeyManager_Expecter {
return &KeyManager_Expecter{mock: &_m.Mock}
}
// ParseJWT provides a mock function for the type KeyManager
func (_mock *KeyManager) ParseJWT(token string) (jwt.Token, error) {
ret := _mock.Called(token)
if len(ret) == 0 {
panic("no return value specified for ParseJWT")
}
var r0 jwt.Token
var r1 error
if returnFunc, ok := ret.Get(0).(func(string) (jwt.Token, error)); ok {
return returnFunc(token)
}
if returnFunc, ok := ret.Get(0).(func(string) jwt.Token); ok {
r0 = returnFunc(token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(jwt.Token)
}
}
if returnFunc, ok := ret.Get(1).(func(string) error); ok {
r1 = returnFunc(token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// KeyManager_ParseJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJWT'
type KeyManager_ParseJWT_Call struct {
*mock.Call
}
// ParseJWT is a helper method to define mock.On call
// - token string
func (_e *KeyManager_Expecter) ParseJWT(token interface{}) *KeyManager_ParseJWT_Call {
return &KeyManager_ParseJWT_Call{Call: _e.mock.On("ParseJWT", token)}
}
func (_c *KeyManager_ParseJWT_Call) Run(run func(token string)) *KeyManager_ParseJWT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 string
if args[0] != nil {
arg0 = args[0].(string)
}
run(
arg0,
)
})
return _c
}
func (_c *KeyManager_ParseJWT_Call) Return(token1 jwt.Token, err error) *KeyManager_ParseJWT_Call {
_c.Call.Return(token1, err)
return _c
}
func (_c *KeyManager_ParseJWT_Call) RunAndReturn(run func(token string) (jwt.Token, error)) *KeyManager_ParseJWT_Call {
_c.Call.Return(run)
return _c
}
// PublicJWKS provides a mock function for the type KeyManager
func (_mock *KeyManager) PublicJWKS() []auth.JWK {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for PublicJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
}
}
return r0
}
// KeyManager_PublicJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublicJWKS'
type KeyManager_PublicJWKS_Call struct {
*mock.Call
}
// PublicJWKS is a helper method to define mock.On call
func (_e *KeyManager_Expecter) PublicJWKS() *KeyManager_PublicJWKS_Call {
return &KeyManager_PublicJWKS_Call{Call: _e.mock.On("PublicJWKS")}
}
func (_c *KeyManager_PublicJWKS_Call) Run(run func()) *KeyManager_PublicJWKS_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *KeyManager_PublicJWKS_Call) Return(jWKs []auth.JWK) *KeyManager_PublicJWKS_Call {
_c.Call.Return(jWKs)
return _c
}
func (_c *KeyManager_PublicJWKS_Call) RunAndReturn(run func() []auth.JWK) *KeyManager_PublicJWKS_Call {
_c.Call.Return(run)
return _c
}
// SignJWT provides a mock function for the type KeyManager
func (_mock *KeyManager) SignJWT(token jwt.Token) ([]byte, error) {
ret := _mock.Called(token)
if len(ret) == 0 {
panic("no return value specified for SignJWT")
}
var r0 []byte
var r1 error
if returnFunc, ok := ret.Get(0).(func(jwt.Token) ([]byte, error)); ok {
return returnFunc(token)
}
if returnFunc, ok := ret.Get(0).(func(jwt.Token) []byte); ok {
r0 = returnFunc(token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if returnFunc, ok := ret.Get(1).(func(jwt.Token) error); ok {
r1 = returnFunc(token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// KeyManager_SignJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignJWT'
type KeyManager_SignJWT_Call struct {
*mock.Call
}
// SignJWT is a helper method to define mock.On call
// - token jwt.Token
func (_e *KeyManager_Expecter) SignJWT(token interface{}) *KeyManager_SignJWT_Call {
return &KeyManager_SignJWT_Call{Call: _e.mock.On("SignJWT", token)}
}
func (_c *KeyManager_SignJWT_Call) Run(run func(token jwt.Token)) *KeyManager_SignJWT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 jwt.Token
if args[0] != nil {
arg0 = args[0].(jwt.Token)
}
run(
arg0,
)
})
return _c
}
func (_c *KeyManager_SignJWT_Call) Return(bytes []byte, err error) *KeyManager_SignJWT_Call {
_c.Call.Return(bytes, err)
return _c
}
func (_c *KeyManager_SignJWT_Call) RunAndReturn(run func(token jwt.Token) ([]byte, error)) *KeyManager_SignJWT_Call {
_c.Call.Return(run)
return _c
}
+7 -7
View File
@@ -1029,19 +1029,19 @@ func (_c *Service_ResetPATSecret_Call) RunAndReturn(run func(ctx context.Context
}
// RetrieveJWKS provides a mock function for the type Service
func (_mock *Service) RetrieveJWKS() []auth.JWK {
func (_mock *Service) RetrieveJWKS() []auth.PublicKeyInfo {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for RetrieveJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
var r0 []auth.PublicKeyInfo
if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
r0 = ret.Get(0).([]auth.PublicKeyInfo)
}
}
return r0
@@ -1064,12 +1064,12 @@ func (_c *Service_RetrieveJWKS_Call) Run(run func()) *Service_RetrieveJWKS_Call
return _c
}
func (_c *Service_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Service_RetrieveJWKS_Call {
_c.Call.Return(jWKs)
func (_c *Service_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call {
_c.Call.Return(publicKeyInfos)
return _c
}
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Service_RetrieveJWKS_Call {
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call {
_c.Call.Return(run)
return _c
}
+17 -8
View File
@@ -169,22 +169,31 @@ func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token
}
// RetrieveJWKS provides a mock function for the type Tokenizer
func (_mock *Tokenizer) RetrieveJWKS() []auth.JWK {
func (_mock *Tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for RetrieveJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
var r0 []auth.PublicKeyInfo
var r1 error
if returnFunc, ok := ret.Get(0).(func() ([]auth.PublicKeyInfo, error)); ok {
return returnFunc()
}
if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
r0 = ret.Get(0).([]auth.PublicKeyInfo)
}
}
return r0
if returnFunc, ok := ret.Get(1).(func() error); ok {
r1 = returnFunc()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Tokenizer_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
@@ -204,12 +213,12 @@ func (_c *Tokenizer_RetrieveJWKS_Call) Run(run func()) *Tokenizer_RetrieveJWKS_C
return _c
}
func (_c *Tokenizer_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Return(jWKs)
func (_c *Tokenizer_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo, err error) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Return(publicKeyInfos, err)
return _c
}
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() ([]auth.PublicKeyInfo, error)) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Return(run)
return _c
}
+2
View File
@@ -186,6 +186,8 @@ func ParseEntityType(et string) (EntityType, error) {
return UsersType, nil
case DashboardsStr:
return DashboardType, nil
case MessagesStr:
return MessagesType, nil
default:
return 0, fmt.Errorf("unknown domain entity type %s", et)
}
+12 -4
View File
@@ -12,6 +12,7 @@ import (
"github.com/absmach/supermq"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/google/uuid"
@@ -84,8 +85,8 @@ type Authn interface {
// other reason, non-nil error value is returned in response.
Identify(ctx context.Context, token string) (Key, error)
// RetrieveJWKS retrieves a JWKs to validate issued tokens.
RetrieveJWKS() []JWK
// RetrieveJWKS retrieves public keys to validate issued tokens.
RetrieveJWKS() []PublicKeyInfo
}
// Service specifies an API that must be fulfilled by the domain service
@@ -201,8 +202,12 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) {
}
}
func (svc service) RetrieveJWKS() []JWK {
return svc.tokenizer.RetrieveJWKS()
func (svc service) RetrieveJWKS() []PublicKeyInfo {
keys, err := svc.tokenizer.RetrieveJWKS()
if err != nil {
return nil
}
return keys
}
func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
@@ -803,6 +808,9 @@ func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (
_, err = svc.pats.Retrieve(ctx, key.Subject, patID)
if err != nil {
if errors.Contains(err, repoerr.ErrNotFound) {
return Key{}, svcerr.ErrNotFound
}
return Key{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
-20
View File
@@ -1,20 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
)
// Tokenizer specifies API for encoding and decoding between string and Key.
type Tokenizer interface {
// Issue converts API Key to its string representation.
Issue(key Key) (token string, err error)
// Parse extracts API Key data from string token.
Parse(ctx context.Context, token string) (key Key, err error)
// RetrieveJWKS returns the JSON Web Key Set.
RetrieveJWKS() []JWK
}
+164
View File
@@ -0,0 +1,164 @@
# Asymmetric Tokenizer
EdDSA (Ed25519) tokenizer with support for zero-downtime key rotation.
## Features
- **Single-key mode** - Simple setup with one active key
- **Two-key mode** - Active + retiring keys for _zero-downtime rotation_
- **JWKS endpoint** - Publishes all valid public keys for token verification
## Configuration
The tokenizer uses environment variables to specify key file paths:
| Environment Variable | Required | Description |
| --------------------------------- | -------- | ------------------------------------------------ |
| `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` | Yes | Path to active private key file |
| `SMQ_AUTH_KEYS_RETIRING_KEY_PATH` | No | Path to retiring private key file (for rotation) |
Please note that key names are used as **key IDs (kid)**.
### Single-Key Mode
Set only the active key path:
```bash
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key"
```
The tokenizer will:
- Issue new tokens signed with the active key
- Verify tokens using the active key
- Return one public key in JWKS endpoint
### Two-Key Mode (Key Rotation)
Set both active and retiring key paths:
```bash
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key"
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
```
The tokenizer will:
- Issue new tokens signed with the active key
- Verify tokens using both active and retiring keys
- Return both public keys in JWKS endpoint
## Key Rotation Process
Zero-downtime key rotation in 3 simple steps:
### 1. Generate New Key
```bash
openssl genpkey -algorithm Ed25519 -out keys/new.key
```
### 2. Update Environment & Restart
Move the current active key to retiring position and set the new key as active:
```bash
# Before rotation
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/current.key"
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # No retiring key
# During rotation (both keys active for grace period)
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key"
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/current.key"
# After rotation (restart service with new config)
docker-compose restart auth
```
During the grace period, tokens signed with either key remain valid.
### 3. Clean Up After Grace Period
After the grace period expires (typically 7-30 days), remove the retiring key:
```bash
# Remove retiring key configuration
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key"
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # Remove retiring key
# Restart service
docker-compose restart auth
# Delete old key file
rm keys/current.key
```
## Grace Period Recommendations
**Recommended:** 168 hours (7 days)
**Minimum:** 24 hours
**Maximum:** 720 hours (30 days)
The grace period should be longer than your longest-lived access token duration.
## Security Best Practices
- Store private keys with `0600` permissions
- Use cryptographically secure key generation:
```bash
openssl genpkey -algorithm Ed25519 -out private.key
chmod 600 private.key
```
- Rotate keys regularly:
- Standard environments: every 90 days
- High-security environments: every 30 days
- Never commit keys to version control
- Use secrets management in production (HashiCorp Vault, AWS Secrets Manager, etc.)
## Example: Complete Rotation
```bash
# Day 0: Normal operation
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2024.pem"
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH=""
# Day 1: Start rotation - generate new key
openssl genpkey -algorithm Ed25519 -out ./keys/key-2025.pem
chmod 600 ./keys/key-2025.pem
# Day 1: Update config and restart
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2025.pem"
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/key-2024.pem"
docker-compose restart auth
# Day 8: Grace period expired - remove old key
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH=""
docker-compose restart auth
rm ./keys/key-2024.pem
```
## Troubleshooting
### Active key not found
```
Error: active key file not found: ./keys/active.key
```
**Solution:** Ensure the file exists and path is correct. Verify `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` environment variable.
### Retiring key warning
If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key:
```
WARN: failed to load retiring key, continuing without it
```
This is by design - a missing retiring key won't prevent startup.
### Invalid key format
```
Error: failed to parse private key
```
**Solution:** Ensure you're using Ed25519 keys in PEM format (PKCS8).
+160
View File
@@ -0,0 +1,160 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package asymmetric_test
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type incrementingIDProvider struct {
counter int
}
func (p *incrementingIDProvider) ID() (string, error) {
p.counter++
return fmt.Sprintf("key-id-%d", p.counter), nil
}
func TestTwoKeyRotation(t *testing.T) {
tmpDir := t.TempDir()
_, activePriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
_, retiringPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
activeKeyPath := filepath.Join(tmpDir, "active.key")
retiringKeyPath := filepath.Join(tmpDir, "retiring.key")
saveKey(t, activePriv, activeKeyPath)
saveKey(t, retiringPriv, retiringKeyPath)
idProvider := &incrementingIDProvider{}
tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger())
require.NoError(t, err)
testKey := auth.Key{
ID: "test-key",
Type: auth.AccessKey,
Subject: "user-123",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
}
token, err := tokenizer.Issue(testKey)
require.NoError(t, err)
assert.NotEmpty(t, token)
verified, err := tokenizer.Parse(context.Background(), token)
require.NoError(t, err, "Should work with active token")
assert.Equal(t, testKey.Subject, verified.Subject)
publicKeys, err := tokenizer.RetrieveJWKS()
require.NoError(t, err)
assert.Len(t, publicKeys, 2, "Should return both active and retiring keys")
keyIDs := make(map[string]bool)
for _, pk := range publicKeys {
keyIDs[pk.KeyID] = true
}
assert.Len(t, keyIDs, 2, "Both keys should have unique IDs")
}
func TestSingleKeyMode(t *testing.T) {
tmpDir := t.TempDir()
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyPath := filepath.Join(tmpDir, "single.key")
saveKey(t, privateKey, keyPath)
idProvider := &mockIDProvider{id: "single-id"}
tokenizer, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
require.NoError(t, err)
testKey := auth.Key{
ID: "test",
Type: auth.AccessKey,
Subject: "user",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
}
token, err := tokenizer.Issue(testKey)
require.NoError(t, err)
_, err = tokenizer.Parse(context.Background(), token)
require.NoError(t, err)
publicKeys, err := tokenizer.RetrieveJWKS()
require.NoError(t, err, "Should return one active key")
assert.Len(t, publicKeys, 1, "Should return only the active key")
}
func TestMissingRetiringKey(t *testing.T) {
tmpDir := t.TempDir()
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
activeKeyPath := filepath.Join(tmpDir, "active.key")
saveKey(t, privateKey, activeKeyPath)
retiringKeyPath := filepath.Join(tmpDir, "nonexistent.key")
idProvider := &mockIDProvider{id: "test-id"}
tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger())
require.NoError(t, err, "Should succeed even if retiring key is missing")
testKey := auth.Key{
ID: "test",
Type: auth.AccessKey,
Subject: "user",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
}
token, err := tokenizer.Issue(testKey)
require.NoError(t, err)
_, err = tokenizer.Parse(context.Background(), token)
require.NoError(t, err)
publicKeys, err := tokenizer.RetrieveJWKS()
require.NoError(t, err)
assert.Len(t, publicKeys, 1, "Should return only active key when retiring key is missing")
}
func saveKey(t *testing.T, privateKey ed25519.PrivateKey, path string) {
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
err = os.WriteFile(path, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
}
+246
View File
@@ -0,0 +1,246 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package asymmetric
import (
"context"
"crypto/ed25519"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"log/slog"
"os"
"path/filepath"
"strings"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
"github.com/absmach/supermq/pkg/errors"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
const patPrefix = "pat"
var (
errLoadingPrivateKey = errors.New("failed to load private key")
errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID")
errInvalidKeySize = errors.New("invalid ED25519 key size")
errParsingPrivateKey = errors.New("failed to parse private key")
errInvalidKeyType = errors.New("private key is not ED25519")
errNoValidPublicKeys = errors.New("no valid public keys available")
errNoActiveKey = errors.New("active key not loaded")
)
type keyPair struct {
id string
privateKey jwk.Key
publicKey jwk.Key
}
// Tokenizer is safe for concurrent use. Keys are set during construction
// and never modified afterward.
type tokenizer struct {
activeKey *keyPair
retiringKey *keyPair // Optional, for key rotation grace period
}
var _ auth.Tokenizer = (*tokenizer)(nil)
// NewTokenizer creates a new asymmetric tokenizer with active and optionally retiring keys.
// activeKeyPath is required. retiringKeyPath is optional (can be empty string).
// If retiringKeyPath is provided but the file doesn't exist or is invalid, a warning is logged
// but the tokenizer is still created with just the active key.
// Key IDs are derived from filenames to ensure consistency across multiple service instances.
func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDProvider, logger *slog.Logger) (auth.Tokenizer, error) {
activeKID := keyIDFromPath(activeKeyPath)
activePrivateJwk, activePublicJwk, err := loadKeyPair(activeKeyPath, activeKID)
if err != nil {
return nil, err
}
mgr := &tokenizer{
activeKey: &keyPair{
id: activeKID,
privateKey: activePrivateJwk,
publicKey: activePublicJwk,
},
}
if retiringKeyPath != "" {
retiringKID := keyIDFromPath(retiringKeyPath)
if retiringKID == activeKID {
return nil, errDuplicateRetiringKeyID
}
retiringPrivateJwk, retiringPublicJwk, err := loadKeyPair(retiringKeyPath, retiringKID)
if err != nil {
logger.Warn("failed to load retiring key, continuing without it", slog.Any("error", err))
return mgr, nil
}
mgr.retiringKey = &keyPair{
id: retiringKID,
privateKey: retiringPrivateJwk,
publicKey: retiringPublicJwk,
}
logger.Info("loaded retiring key for rotation grace period", slog.String("key_id", retiringKID))
}
return mgr, nil
}
func (km *tokenizer) Issue(key auth.Key) (string, error) {
if km.activeKey == nil {
return "", errNoActiveKey
}
tkn, err := smqjwt.BuildToken(key)
if err != nil {
return "", err
}
headers := jws.NewHeaders()
if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil {
return "", err
}
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", err
}
return string(signedBytes), nil
}
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
set := jwk.NewSet()
if err := set.AddKey(km.activeKey.publicKey); err != nil {
return auth.Key{}, err
}
if km.retiringKey != nil {
if err := set.AddKey(km.retiringKey.publicKey); err != nil {
return auth.Key{}, err
}
}
tkn, err := jwt.Parse(
[]byte(tokenString),
jwt.WithValidate(true),
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
)
if err != nil {
return auth.Key{}, err
}
if tkn.Issuer() != smqjwt.IssuerName {
return auth.Key{}, smqjwt.ErrInvalidIssuer
}
return smqjwt.ToKey(tkn)
}
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
publicKeys := make([]auth.PublicKeyInfo, 0, 2)
if km.activeKey != nil {
if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil {
publicKeys = append(publicKeys, *pkInfo)
}
}
if km.retiringKey != nil {
if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil {
publicKeys = append(publicKeys, *pkInfo)
}
}
if len(publicKeys) == 0 {
return nil, errNoValidPublicKeys
}
return publicKeys, nil
}
func extractPublicKeyInfo(kp *keyPair) *auth.PublicKeyInfo {
var rawKey ed25519.PublicKey
if err := kp.publicKey.Raw(&rawKey); err != nil {
return nil
}
return &auth.PublicKeyInfo{
KeyID: kp.id,
KeyType: "OKP",
Algorithm: "EdDSA",
Use: "sig",
Curve: "Ed25519",
X: base64.RawURLEncoding.EncodeToString(rawKey),
}
}
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return nil, nil, errors.Wrap(errLoadingPrivateKey, err)
}
var privateKey ed25519.PrivateKey
block, _ := pem.Decode(privateKeyBytes)
switch {
case block != nil:
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(errParsingPrivateKey, err)
}
var ok bool
privateKey, ok = parsedKey.(ed25519.PrivateKey)
if !ok {
return nil, nil, errInvalidKeyType
}
default:
if len(privateKeyBytes) != ed25519.PrivateKeySize {
return nil, nil, errInvalidKeySize
}
privateKey = ed25519.PrivateKey(privateKeyBytes)
}
publicKey := privateKey.Public().(ed25519.PublicKey)
privateJwk, err := jwk.FromRaw(privateKey)
if err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
publicJwk, err := jwk.FromRaw(publicKey)
if err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
return privateJwk, publicJwk, nil
}
func keyIDFromPath(path string) string {
base := filepath.Base(path)
ext := filepath.Ext(base)
return strings.TrimSuffix(base, ext)
}
+418
View File
@@ -0,0 +1,418 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package asymmetric_test
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockIDProvider struct {
id string
}
func (m *mockIDProvider) ID() (string, error) {
return m.id, nil
}
func newTestLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
}
func TestNewKeyManager(t *testing.T) {
idProvider := &mockIDProvider{id: "unused"}
tmpDir := t.TempDir()
keyPath := filepath.Join(tmpDir, "private.key")
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
cases := []struct {
name string
setupKey func() string
expectErr bool
errContains string
}{
{
name: "valid PEM key",
setupKey: func() string {
err := os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
return keyPath
},
expectErr: false,
},
{
name: "valid raw key",
setupKey: func() string {
rawKeyPath := filepath.Join(tmpDir, "raw_private.key")
err := os.WriteFile(rawKeyPath, privateKey, 0o600)
require.NoError(t, err)
return rawKeyPath
},
expectErr: false,
},
{
name: "non-existent key file",
setupKey: func() string {
return filepath.Join(tmpDir, "nonexistent.key")
},
expectErr: true,
errContains: "failed to load private key",
},
{
name: "invalid key size",
setupKey: func() string {
invalidPath := filepath.Join(tmpDir, "invalid.key")
err := os.WriteFile(invalidPath, []byte("invalid"), 0o600)
require.NoError(t, err)
return invalidPath
},
expectErr: true,
errContains: "invalid ED25519 key size",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
path := tc.setupKey()
km, err := asymmetric.NewTokenizer(path, "", idProvider, newTestLogger())
if tc.expectErr {
assert.Error(t, err)
if tc.errContains != "" {
assert.Contains(t, err.Error(), tc.errContains)
}
assert.Nil(t, km)
} else {
assert.NoError(t, err)
assert.NotNil(t, km)
}
})
}
}
func TestSign(t *testing.T) {
idProvider := &mockIDProvider{id: "unused"}
tmpDir := t.TempDir()
keyPath := filepath.Join(tmpDir, "private.key")
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
require.NoError(t, err)
cases := []struct {
name string
key auth.Key
}{
{
name: "sign valid key with all fields",
key: auth.Key{
ID: "key-id",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
},
},
{
name: "sign key without subject",
key: auth.Key{
ID: "key-id",
Type: auth.APIKey,
Issuer: "supermq.auth",
Role: auth.AdminRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
Verified: false,
},
},
{
name: "sign key without ID",
key: auth.Key{
Type: auth.AccessKey,
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token, err := km.Issue(tc.key)
assert.NoError(t, err)
assert.NotEmpty(t, token)
parts := splitJWT(token)
assert.Equal(t, 3, len(parts), "JWT should have 3 parts")
})
}
}
func TestVerify(t *testing.T) {
idProvider := &mockIDProvider{id: "unused"}
tmpDir := t.TempDir()
keyPath := filepath.Join(tmpDir, "private.key")
kid := "private"
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
require.NoError(t, err)
validKey := auth.Key{
ID: "key-id",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
}
validToken, err := km.Issue(validKey)
require.NoError(t, err, "Signing a valid token should succeed")
expiredKey := validKey
expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC()
expiredToken, err := km.Issue(expiredKey)
require.NoError(t, err, "Creating an expired token should succeed")
wrongIssuerKey := validKey
wrongIssuerKey.Issuer = "wrong.issuer"
privateJwk, err := jwk.FromRaw(privateKey)
require.NoError(t, err)
require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid))
builder := jwt.NewBuilder()
builder.Issuer(wrongIssuerKey.Issuer).
Subject(wrongIssuerKey.Subject).
IssuedAt(wrongIssuerKey.IssuedAt).
Expiration(wrongIssuerKey.ExpiresAt).
JwtID(wrongIssuerKey.ID).
Claim("type", wrongIssuerKey.Type).
Claim("role", wrongIssuerKey.Role).
Claim("verified", wrongIssuerKey.Verified)
wrongIssuerJWT, err := builder.Build()
require.NoError(t, err)
wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.EdDSA, privateJwk))
require.NoError(t, err)
wrongIssuerToken := string(wrongIssuerTokenBytes)
cases := []struct {
name string
token string
expectErr bool
errContains string
}{
{
name: "verify valid token",
token: validToken,
expectErr: false,
},
{
name: "verify expired token",
token: expiredToken,
expectErr: true,
errContains: "exp",
},
{
name: "verify token with wrong issuer",
token: wrongIssuerToken,
expectErr: true,
errContains: "invalid token issuer",
},
{
name: "verify malformed token",
token: "not.a.valid.jwt",
expectErr: true,
errContains: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
key, err := km.Parse(context.Background(), tc.token)
if tc.expectErr {
assert.Error(t, err)
if tc.errContains != "" {
assert.Contains(t, err.Error(), tc.errContains)
}
} else {
assert.NoError(t, err)
assert.Equal(t, validKey.Subject, key.Subject)
assert.Equal(t, validKey.Type, key.Type)
assert.Equal(t, validKey.Role, key.Role)
}
})
}
}
func TestPublicKeys(t *testing.T) {
idProvider := &mockIDProvider{id: "unused"}
tmpDir := t.TempDir()
keyPath := filepath.Join(tmpDir, "private.key")
kid := "private"
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
require.NoError(t, err)
keys, err := km.RetrieveJWKS()
assert.NoError(t, err)
assert.Len(t, keys, 1)
key := keys[0]
assert.Equal(t, kid, key.KeyID)
assert.Equal(t, "OKP", key.KeyType)
assert.Equal(t, "EdDSA", key.Algorithm)
assert.Equal(t, "sig", key.Use)
assert.Equal(t, "Ed25519", key.Curve)
assert.NotEmpty(t, key.X)
decoded, err := base64.RawURLEncoding.DecodeString(key.X)
assert.NoError(t, err, "The public key should be decoded")
assert.Equal(t, publicKey, ed25519.PublicKey(decoded))
}
func TestSignAndVerifyRoundTrip(t *testing.T) {
idProvider := &mockIDProvider{id: "unused"}
tmpDir := t.TempDir()
keyPath := filepath.Join(tmpDir, "private.key")
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Key,
}
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
require.NoError(t, err)
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
require.NoError(t, err)
originalKey := auth.Key{
ID: "key-123",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-456",
Role: auth.UserRole,
IssuedAt: time.Now().UTC().Truncate(time.Second),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second),
Verified: true,
}
token, err := km.Issue(originalKey)
require.NoError(t, err)
verifiedKey, err := km.Parse(context.Background(), token)
require.NoError(t, err, "Verification of a valid key should succeed")
assert.Equal(t, originalKey.ID, verifiedKey.ID)
assert.Equal(t, originalKey.Type, verifiedKey.Type)
assert.Equal(t, originalKey.Subject, verifiedKey.Subject)
assert.Equal(t, originalKey.Role, verifiedKey.Role)
assert.Equal(t, originalKey.Verified, verifiedKey.Verified)
assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second)
assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second)
}
func splitJWT(token string) []string {
parts := []string{}
start := 0
for i := 0; i < len(token); i++ {
if token[i] == '.' {
parts = append(parts, token[start:i])
start = i + 1
}
}
parts = append(parts, token[start:])
return parts
}
+84
View File
@@ -0,0 +1,84 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package symmetric
import (
"context"
"github.com/absmach/supermq/auth"
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)
const (
patPrefix = "pat"
)
var errJWTExpiryKey = errors.New(`"exp" not satisfied`)
type tokenizer struct {
algorithm jwa.KeyAlgorithm
secret []byte
}
var _ auth.Tokenizer = (*tokenizer)(nil)
func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) {
alg := jwa.KeyAlgorithmFrom(algorithm)
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
return nil, auth.ErrUnsupportedKeyAlgorithm
}
if len(secret) == 0 {
return nil, auth.ErrInvalidSymmetricKey
}
return &tokenizer{
secret: secret,
algorithm: alg,
}, nil
}
func (km *tokenizer) Issue(key auth.Key) (string, error) {
tkn, err := smqjwt.BuildToken(key)
if err != nil {
return "", err
}
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret))
if err != nil {
return "", err
}
return string(signedBytes), nil
}
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
tkn, err := jwt.Parse(
[]byte(tokenString),
jwt.WithValidate(true),
jwt.WithKey(km.algorithm, km.secret),
)
if err != nil {
if errors.Contains(err, errJWTExpiryKey) {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry)
}
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if tkn.Issuer() != smqjwt.IssuerName {
return auth.Key{}, smqjwt.ErrInvalidIssuer
}
return smqjwt.ToKey(tkn)
}
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
return nil, auth.ErrPublicKeysNotSupported
}
+362
View File
@@ -0,0 +1,362 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package symmetric_test
import (
"context"
"testing"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/tokenizer/symmetric"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewTokenizer(t *testing.T) {
cases := []struct {
name string
algorithm string
secret []byte
expectErr bool
errContains string
}{
{
name: "valid HS256 algorithm",
algorithm: "HS256",
secret: []byte("my-secret-key-32-bytes-long!!"),
expectErr: false,
},
{
name: "valid HS384 algorithm",
algorithm: "HS384",
secret: []byte("my-secret-key-48-bytes-long-for-hs384-!!!!!!"),
expectErr: false,
},
{
name: "valid HS512 algorithm",
algorithm: "HS512",
secret: []byte("my-secret-key-64-bytes-long-for-hs512-algorithm-testing!!!!"),
expectErr: false,
},
{
name: "invalid algorithm",
algorithm: "INVALID_ALG",
secret: []byte("my-secret-key"),
expectErr: true,
errContains: "unsupported key algorithm",
},
{
name: "empty secret",
algorithm: "HS256",
secret: []byte{},
expectErr: true,
errContains: "invalid symmetric key",
},
{
name: "nil secret",
algorithm: "HS256",
secret: nil,
expectErr: true,
errContains: "invalid symmetric key",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
km, err := symmetric.NewTokenizer(tc.algorithm, tc.secret)
if tc.expectErr {
assert.Error(t, err)
if tc.errContains != "" {
assert.Contains(t, err.Error(), tc.errContains)
}
assert.Nil(t, km)
} else {
assert.NoError(t, err)
assert.NotNil(t, km)
}
})
}
}
func TestSign(t *testing.T) {
secret := []byte("my-super-secret-key-for-testing")
km, err := symmetric.NewTokenizer("HS256", secret)
require.NoError(t, err)
cases := []struct {
name string
key auth.Key
}{
{
name: "sign valid key with all fields",
key: auth.Key{
ID: "key-id",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
},
},
{
name: "sign key without subject",
key: auth.Key{
ID: "key-id",
Type: auth.APIKey,
Issuer: "supermq.auth",
Role: auth.AdminRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
Verified: false,
},
},
{
name: "sign key without ID",
key: auth.Key{
Type: auth.AccessKey,
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token, err := km.Issue(tc.key)
assert.NoError(t, err)
assert.NotEmpty(t, token)
// Verify the token is valid JWT format (3 parts separated by dots)
parts := splitJWT(token)
assert.Equal(t, 3, len(parts), "JWT should have 3 parts")
})
}
}
func TestVerify(t *testing.T) {
secret := []byte("my-super-secret-key-for-testing")
km, err := symmetric.NewTokenizer("HS256", secret)
require.NoError(t, err)
validKey := auth.Key{
ID: "key-id",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
}
validToken, err := km.Issue(validKey)
require.NoError(t, err, "Signing valid token should succeed")
expiredKey := validKey
expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC()
expiredToken, err := km.Issue(expiredKey)
require.NoError(t, err)
wrongIssuerKey := validKey
wrongIssuerKey.Issuer = "wrong.issuer"
builder := jwt.NewBuilder()
builder.Issuer(wrongIssuerKey.Issuer).
Subject(wrongIssuerKey.Subject).
IssuedAt(wrongIssuerKey.IssuedAt).
Expiration(wrongIssuerKey.ExpiresAt).
JwtID(wrongIssuerKey.ID).
Claim("type", wrongIssuerKey.Type).
Claim("role", wrongIssuerKey.Role).
Claim("verified", wrongIssuerKey.Verified)
wrongIssuerJWT, err := builder.Build()
require.NoError(t, err)
wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.HS256, secret))
require.NoError(t, err)
wrongIssuerToken := string(wrongIssuerTokenBytes)
wrongSecretKM, err := symmetric.NewTokenizer("HS256", []byte("different-secret-key-here"))
require.NoError(t, err)
wrongSecretToken, err := wrongSecretKM.Issue(validKey)
require.NoError(t, err)
cases := []struct {
name string
token string
expectErr bool
errContains string
}{
{
name: "verify valid token",
token: validToken,
expectErr: false,
},
{
name: "verify expired token",
token: expiredToken,
expectErr: true,
errContains: "exp",
},
{
name: "verify token with wrong issuer",
token: wrongIssuerToken,
expectErr: true,
errContains: "invalid token issuer",
},
{
name: "verify token with wrong secret",
token: wrongSecretToken,
expectErr: true,
errContains: "",
},
{
name: "verify malformed token",
token: "not.a.valid.jwt",
expectErr: true,
errContains: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
key, err := km.Parse(context.Background(), tc.token)
if tc.expectErr {
assert.Error(t, err)
if tc.errContains != "" {
assert.Contains(t, err.Error(), tc.errContains)
}
} else {
assert.NoError(t, err)
assert.Equal(t, validKey.Subject, key.Subject)
assert.Equal(t, validKey.Type, key.Type)
assert.Equal(t, validKey.Role, key.Role)
}
})
}
}
func TestPublicKeys(t *testing.T) {
secret := []byte("my-super-secret-key-for-testing")
km, err := symmetric.NewTokenizer("HS256", secret)
require.NoError(t, err)
keys, err := km.RetrieveJWKS()
assert.Error(t, err)
assert.Equal(t, auth.ErrPublicKeysNotSupported, err)
assert.Nil(t, keys)
}
func TestSignAndVerifyRoundTrip(t *testing.T) {
algorithms := []string{"HS256", "HS384", "HS512"}
for _, alg := range algorithms {
t.Run(alg, func(t *testing.T) {
secret := []byte("my-super-secret-key-for-testing-" + alg)
km, err := symmetric.NewTokenizer(alg, secret)
require.NoError(t, err)
originalKey := auth.Key{
ID: "key-123",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-456",
Role: auth.UserRole,
IssuedAt: time.Now().UTC().Truncate(time.Second),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second),
Verified: true,
}
token, err := km.Issue(originalKey)
require.NoError(t, err)
verifiedKey, err := km.Parse(context.Background(), token)
require.NoError(t, err)
assert.Equal(t, originalKey.ID, verifiedKey.ID)
assert.Equal(t, originalKey.Type, verifiedKey.Type)
assert.Equal(t, originalKey.Subject, verifiedKey.Subject)
assert.Equal(t, originalKey.Role, verifiedKey.Role)
assert.Equal(t, originalKey.Verified, verifiedKey.Verified)
assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second)
assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second)
})
}
}
func TestDifferentAlgorithms(t *testing.T) {
secret := []byte("my-super-secret-key-for-testing-algorithms")
key := auth.Key{
ID: "key-id",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Subject: "user-id",
Role: auth.UserRole,
IssuedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
Verified: true,
}
km256, err := symmetric.NewTokenizer("HS256", secret)
require.NoError(t, err)
token256, err := km256.Issue(key)
require.NoError(t, err)
km384, err := symmetric.NewTokenizer("HS384", secret)
require.NoError(t, err)
token384, err := km384.Issue(key)
require.NoError(t, err)
km512, err := symmetric.NewTokenizer("HS512", secret)
require.NoError(t, err)
token512, err := km512.Issue(key)
require.NoError(t, err)
assert.NotEqual(t, token256, token384)
assert.NotEqual(t, token256, token512)
assert.NotEqual(t, token384, token512)
_, err = km256.Parse(context.Background(), token256)
assert.NoError(t, err, "verification of km256 token should pass with km256 verifier")
_, err = km384.Parse(context.Background(), token384)
assert.NoError(t, err, "verification of km384 token should pass with km384 verifier")
_, err = km512.Parse(context.Background(), token512)
assert.NoError(t, err, "verification of km512 token should pass with km512 verifier")
_, err = km384.Parse(context.Background(), token256)
assert.Error(t, err, "Cross verification should fail")
_, err = km512.Parse(context.Background(), token256)
assert.Error(t, err)
}
func splitJWT(token string) []string {
parts := []string{}
start := 0
for i := 0; i < len(token); i++ {
if token[i] == '.' {
parts = append(parts, token[start:i])
start = i + 1
}
}
parts = append(parts, token[start:])
return parts
}
+110
View File
@@ -0,0 +1,110 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package util
import (
"encoding/json"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
// ErrInvalidIssuer represents an invalid token issuer value.
ErrInvalidIssuer = errors.New("invalid token issuer value")
// ErrJSONHandle indicates an error in handling JSON.
ErrJSONHandle = errors.New("failed to perform operation JSON")
errInvalidType = errors.New("invalid token type")
errInvalidRole = errors.New("invalid role")
errInvalidVerified = errors.New("invalid verified")
)
const (
IssuerName = "supermq.auth"
TokenType = "type"
RoleField = "role"
VerifiedField = "verified"
)
// ToKey converts a JWT token to an auth.Key by extracting claims.
func ToKey(tkn jwt.Token) (auth.Key, error) {
data, err := json.Marshal(tkn.PrivateClaims())
if err != nil {
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
var key auth.Key
if err := json.Unmarshal(data, &key); err != nil {
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
tType, ok := tkn.Get(TokenType)
if !ok {
return auth.Key{}, errInvalidType
}
kType, ok := tType.(float64)
if !ok {
return auth.Key{}, errInvalidType
}
kt := auth.KeyType(kType)
if !kt.Validate() {
return auth.Key{}, errInvalidType
}
tRole, ok := tkn.Get(RoleField)
if !ok {
return auth.Key{}, errInvalidRole
}
kRole, ok := tRole.(float64)
if !ok {
return auth.Key{}, errInvalidRole
}
tVerified, ok := tkn.Get(VerifiedField)
if !ok {
return auth.Key{}, errInvalidVerified
}
kVerified, ok := tVerified.(bool)
if !ok {
return auth.Key{}, errInvalidVerified
}
kr := auth.Role(kRole)
if !kr.Validate() {
return auth.Key{}, errInvalidRole
}
key.ID = tkn.JwtID()
key.Type = auth.KeyType(kType)
key.Role = auth.Role(kRole)
key.Issuer = tkn.Issuer()
key.Subject = tkn.Subject()
key.IssuedAt = tkn.IssuedAt()
key.ExpiresAt = tkn.Expiration()
key.Verified = kVerified
return key, nil
}
func BuildToken(key auth.Key) (jwt.Token, error) {
builder := jwt.NewBuilder()
builder.
Issuer(IssuerName).
IssuedAt(key.IssuedAt).
Claim(TokenType, key.Type).
Expiration(key.ExpiresAt).
Claim(RoleField, key.Role).
Claim(VerifiedField, key.Verified)
if key.Subject != "" {
builder.Subject(key.Subject)
}
if key.ID != "" {
builder.JwtID(key.ID)
}
return builder.Build()
}
+1 -2
View File
@@ -136,8 +136,7 @@ type Service interface {
// ChannelRepository specifies a channel persistence API.
type Repository interface {
// Save persists multiple channels. Channels are saved using a transaction. If one channel
// fails then none will be saved. Successful operation is indicated by non-nil
// error response.
// fails then none will be saved. Successful operation is indicated by non-nil error response.
Save(ctx context.Context, chs ...Channel) ([]Channel, error)
// Update performs an update to the existing channel.
+42 -11
View File
@@ -22,11 +22,10 @@ import (
httpapi "github.com/absmach/supermq/auth/api/http"
"github.com/absmach/supermq/auth/cache"
"github.com/absmach/supermq/auth/hasher"
"github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/auth/keymanager/asymmetric"
"github.com/absmach/supermq/auth/keymanager/symmetric"
"github.com/absmach/supermq/auth/middleware"
apostgres "github.com/absmach/supermq/auth/postgres"
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
"github.com/absmach/supermq/auth/tokenizer/symmetric"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/jaeger"
@@ -69,7 +68,8 @@ type config struct {
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"`
PrivateKeyPath string `env:"SMQ_AUTH_KEYS_PRIVATE_KEY_PATH" envDefault:"./ssl/keys/private.pem"`
ActiveKeyPath string `env:"SMQ_AUTH_KEYS_ACTIVE_KEY_PATH" envDefault:"./keys/active.key"`
RetiringKeyPath string `env:"SMQ_AUTH_KEYS_RETIRING_KEY_PATH" envDefault:""`
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
@@ -159,17 +159,23 @@ func main() {
idProvider := uuid.New()
var keyManager auth.KeyManager
if err := validateKeyConfig(isSymmetric, cfg, logger); err != nil {
logger.Error(fmt.Sprintf("invalid key configuration: %s", err.Error()))
exitCode = 1
return
}
var tokenizer auth.Tokenizer
switch {
case isSymmetric:
keyManager, err = symmetric.NewKeyManager(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
tokenizer, err = symmetric.NewTokenizer(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
if err != nil {
logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error()))
exitCode = 1
return
}
default:
keyManager, err = asymmetric.NewKeyManager(cfg.PrivateKeyPath, idProvider)
tokenizer, err = asymmetric.NewTokenizer(cfg.ActiveKeyPath, cfg.RetiringKeyPath, idProvider, logger)
if err != nil {
logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error()))
exitCode = 1
@@ -177,7 +183,7 @@ func main() {
}
}
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, keyManager, idProvider)
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, tokenizer, idProvider)
if err != nil {
logger.Error(fmt.Sprintf("failed to create service : %s\n", err.Error()))
exitCode = 1
@@ -258,7 +264,34 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
return nil
}
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, keyManager auth.KeyManager, idProvider supermq.IDProvider) (auth.Service, error) {
func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error {
if isSymmetric {
if cfg.SecretKey == "secret" {
return fmt.Errorf("default secret key is insecure - please set SMQ_AUTH_SECRET_KEY environment variable")
}
return nil
}
// Validate active key path
_, err := os.Stat(cfg.ActiveKeyPath)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("active key file not found: %s - please set SMQ_AUTH_KEYS_ACTIVE_KEY_PATH", cfg.ActiveKeyPath)
}
return fmt.Errorf("failed to access active key file: %w", err)
}
// Retiring key is optional - only validate if path is provided
if cfg.RetiringKeyPath != "" {
if _, err := os.Stat(cfg.RetiringKeyPath); err != nil {
l.Warn("retiring key path provided but file not accessible", slog.Any("error", err))
}
}
return nil
}
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) {
cache := cache.NewPatsCache(cacheClient, keyDuration)
database := pgclient.NewDatabase(db, dbConfig, tracer)
@@ -269,8 +302,6 @@ func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
pService := spicedb.NewPolicyService(spicedbClient, logger)
tokenizer := jwt.New(keyManager)
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc = middleware.NewLogging(svc, logger)
counter, latency := prometheus.MakeMetrics("auth", "api")
+2 -1
View File
@@ -101,7 +101,8 @@ SMQ_AUTH_DB_SSL_ROOT_CERT=
SMQ_AUTH_ACCESS_TOKEN_DURATION="1h"
SMQ_AUTH_REFRESH_TOKEN_DURATION="24h"
SMQ_AUTH_KEYS_ALGORITHM="EdDSA"
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH="./ssl/keys/private.key"
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key"
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
SMQ_AUTH_INVITATION_DURATION="168h"
SMQ_AUTH_ADAPTER_INSTANCE_ID=
SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0
+11 -5
View File
@@ -121,7 +121,8 @@ services:
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:+/keys/private.key}
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH:+/keys/active.key}
SMQ_AUTH_KEYS_RETIRING_KEY_PATH: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+/keys/retiring.key}
## Compose supports parameter expansion in environment,
## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty
## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default
@@ -152,11 +153,16 @@ services:
volumes:
- ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE}
- supermq-pat-db-volume:/supermq-data
- supermq-auth-keys-volume:/keys
# Auth private key file
# Auth active private key file
- type: bind
source: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:-ssl/certs/dummy/private_key}
target: /keys/private.key
source: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH}
target: /keys/active.key
read_only: true
# Auth retiring private key file (optional, for key rotation)
- type: bind
source: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:-ssl/certs/dummy/retiring_key}
target: /keys/retiring${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+.key}
read_only: true
bind:
create_host_path: true
# Auth gRPC mTLS server certificates
+3
View File
@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIE9Qu5lN6KOfdO14XJUClM1UPrqT55BczLMcRuSG7Ziy
-----END PRIVATE KEY-----
+78 -34
View File
@@ -5,6 +5,7 @@ package jwks
import (
"context"
"fmt"
"io"
"net/http"
"strings"
@@ -14,7 +15,7 @@ import (
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
smqauth "github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/api/grpc/auth"
smqjwt "github.com/absmach/supermq/auth/jwt"
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -26,8 +27,13 @@ import (
)
const (
issuerName = "supermq.auth"
cacheDuration = 5 * time.Minute
issuerName = "supermq.auth"
acceptHeader = "application/json"
fetchKeyDeadline = 10 * time.Second
cacheDuration = 5 * time.Minute
errorBodyBytes = 1024
)
var (
@@ -39,12 +45,6 @@ var (
errInvalidIssuer = errors.New("invalid token issuer value")
// ErrValidateJWTToken indicates a failure to validate JWT token.
errValidateJWTToken = errors.New("failed to validate jwt token")
jwksCache = struct {
sync.RWMutex
jwks jwk.Set
cachedAt time.Time
}{}
)
var _ authn.Authentication = (*authentication)(nil)
@@ -52,6 +52,14 @@ var _ authn.Authentication = (*authentication)(nil)
type authentication struct {
jwksURL string
authSvcClient grpcAuthV1.AuthServiceClient
httpClient *http.Client
cache *jwksCache
}
type jwksCache struct {
mu sync.RWMutex
jwks jwk.Set
cachedAt time.Time
}
func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
@@ -69,9 +77,13 @@ func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Confi
}
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
httpClient := &http.Client{}
return authentication{
jwksURL: jwksURL,
authSvcClient: authSvcClient,
httpClient: httpClient,
cache: &jwksCache{},
}, client, nil
}
@@ -84,14 +96,26 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
}
jwks, err := a.fetchJWKS(ctx)
jwks, err := a.fetchJWKS(ctx, false)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
tkn, err := validateToken(token, jwks)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
// If signature verification failed, try force with refresh JWKS (key rotation scenario)
if isSignatureError(err) {
jwks, fetchErr := a.fetchJWKS(ctx, true)
if fetchErr == nil {
tkn, err = validateToken(token, jwks)
}
}
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
key, err := smqjwt.ToKey(tkn)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
@@ -105,45 +129,63 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
}, nil
}
func (a authentication) fetchJWKS(ctx context.Context) (jwk.Set, error) {
jwksCache.RLock()
if time.Since(jwksCache.cachedAt) < cacheDuration && jwksCache.jwks.Len() > 0 {
cached := jwksCache.jwks
jwksCache.RUnlock()
return cached, nil
}
jwksCache.RUnlock()
func isSignatureError(err error) bool {
return !errors.Contains(err, errJWTExpiryKey) &&
!errors.Contains(err, errInvalidIssuer) &&
!errors.Contains(err, smqauth.ErrExpiry)
}
req, err := http.NewRequestWithContext(ctx, "GET", a.jwksURL, nil)
if err != nil {
return nil, err
func (a authentication) fetchJWKS(ctx context.Context, forceRefresh bool) (jwk.Set, error) {
if !forceRefresh {
a.cache.mu.RLock()
if time.Since(a.cache.cachedAt) < cacheDuration && a.cache.jwks.Len() > 0 {
cached := a.cache.jwks
a.cache.mu.RUnlock()
return cached, nil
}
a.cache.mu.RUnlock()
}
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: 10 * time.Second,
fetchCtx := ctx
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
fetchCtx, cancel = context.WithTimeout(ctx, fetchKeyDeadline)
defer cancel()
}
resp, err := httpClient.Do(req)
// Fetch fresh JWKS from auth service
req, err := http.NewRequestWithContext(fetchCtx, http.MethodGet, a.jwksURL, nil)
if err != nil {
return nil, err
return nil, errors.Wrap(errFetchJWKS, err)
}
req.Header.Set("Accept", acceptHeader)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, errors.Wrap(errFetchJWKS, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, errFetchJWKS
// Read error body for better diagnostics
body, _ := io.ReadAll(io.LimitReader(resp.Body, errorBodyBytes))
return nil, errors.Wrap(errFetchJWKS, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, errors.Wrap(errFetchJWKS, err)
}
set, err := jwk.Parse(data)
if err != nil {
return nil, err
return nil, errors.Wrap(errFetchJWKS, err)
}
jwksCache.Lock()
jwksCache.jwks = set
jwksCache.cachedAt = time.Now()
jwksCache.Unlock()
a.cache.mu.Lock()
a.cache.jwks = set
a.cache.cachedAt = time.Now()
a.cache.mu.Unlock()
return set, nil
}
@@ -160,6 +202,8 @@ func validateToken(token string, jwks jwk.Set) (jwt.Token, error) {
}
return nil, err
}
// Validate issuer
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
if t.Issuer() != issuerName {
return jwt.NewValidationError(errInvalidIssuer)
+336
View File
@@ -0,0 +1,336 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package jwks_test
import (
"crypto/ed25519"
"crypto/rand"
"testing"
"time"
"github.com/absmach/supermq/auth"
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
validIssuer = "supermq.auth"
invalidIssuer = "invalid.issuer"
userID = "user123"
)
func TestValidateToken(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
kid := "test-key-id"
privateJwk, err := jwk.FromRaw(privateKey)
require.NoError(t, err, "Parsing JWKS private key should work")
require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid))
publicJwk, err := jwk.FromRaw(publicKey)
require.NoError(t, err, "Parsing JWKS public key should work")
require.NoError(t, publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk.Set(jwk.KeyIDKey, kid))
jwksSet := jwk.NewSet()
require.NoError(t, jwksSet.AddKey(publicJwk), "Creation of JWKS should succeed")
cases := []struct {
name string
issuer string
expiry time.Time
expectErr bool
}{
{
name: "valid token",
issuer: validIssuer,
expiry: time.Now().Add(1 * time.Hour),
expectErr: false,
},
{
name: "expired token",
issuer: validIssuer,
expiry: time.Now().Add(-1 * time.Hour),
expectErr: true,
},
{
name: "invalid issuer",
issuer: invalidIssuer,
expiry: time.Now().Add(1 * time.Hour),
expectErr: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
builder := jwt.NewBuilder()
builder.Issuer(tc.issuer)
builder.IssuedAt(time.Now())
builder.Expiration(tc.expiry)
builder.Subject(userID)
builder.Claim(smqjwt.TokenType, auth.AccessKey)
builder.Claim(smqjwt.RoleField, auth.UserRole)
builder.Claim(smqjwt.VerifiedField, true)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk))
require.NoError(t, err)
parsedToken, err := jwt.Parse(
signedToken,
jwt.WithValidate(true),
jwt.WithKeySet(jwksSet),
)
if tc.expectErr {
// For expired token, parsing will fail
// For invalid issuer, parsing succeeds but we need to check issuer
if tc.issuer != validIssuer && err == nil {
if parsedToken.Issuer() != validIssuer {
assert.NotEqual(t, validIssuer, parsedToken.Issuer())
} else {
assert.Fail(t, "Expected invalid issuer error")
}
} else {
assert.Error(t, err)
}
} else {
assert.NoError(t, err)
assert.NotNil(t, parsedToken)
assert.Equal(t, tc.issuer, parsedToken.Issuer())
}
})
}
}
func TestMultiKeyJWKS(t *testing.T) {
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, priv2, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
activeKID := "key-active"
retiringKID := "key-retiring"
privateJwk1, err := jwk.FromRaw(priv1)
require.NoError(t, err, "Parsing JWKS private key 1 should work")
require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, activeKID))
publicJwk1, err := jwk.FromRaw(pub1)
require.NoError(t, err, "Parsing JWKS public key 1 should work")
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, activeKID))
privateJwk2, err := jwk.FromRaw(priv2)
require.NoError(t, err, "Parsing JWKS private key 2 should work")
require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, retiringKID))
publicJwk2, err := jwk.FromRaw(pub2)
require.NoError(t, err, "Parsing JWKS public key 2 should work")
require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, retiringKID))
// Create JWKS set with both keys (simulating rotation period)
jwksSet := jwk.NewSet()
require.NoError(t, jwksSet.AddKey(publicJwk1))
require.NoError(t, jwksSet.AddKey(publicJwk2))
assert.Equal(t, 2, jwksSet.Len(), "JWKS should contain both keys")
cases := []struct {
name string
privateKey jwk.Key
kid string
issuer string
expiry time.Time
expectErr bool
}{
{
name: "token signed with active key",
privateKey: privateJwk1,
kid: activeKID,
issuer: validIssuer,
expiry: time.Now().Add(1 * time.Hour),
expectErr: false,
},
{
name: "token signed with retiring key",
privateKey: privateJwk2,
kid: retiringKID,
issuer: validIssuer,
expiry: time.Now().Add(1 * time.Hour),
expectErr: false,
},
{
name: "expired token with active key",
privateKey: privateJwk1,
kid: activeKID,
issuer: validIssuer,
expiry: time.Now().Add(-1 * time.Hour),
expectErr: true,
},
{
name: "expired token with retiring key",
privateKey: privateJwk2,
kid: retiringKID,
issuer: validIssuer,
expiry: time.Now().Add(-1 * time.Hour),
expectErr: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
builder := jwt.NewBuilder()
builder.Issuer(tc.issuer)
builder.IssuedAt(time.Now())
builder.Expiration(tc.expiry)
builder.Subject(userID)
builder.Claim(smqjwt.TokenType, auth.AccessKey)
builder.Claim(smqjwt.RoleField, auth.UserRole)
builder.Claim(smqjwt.VerifiedField, true)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, tc.privateKey))
require.NoError(t, err)
// Parse the token using the multi-key JWKS set
parsedToken, err := jwt.Parse(
signedToken,
jwt.WithValidate(true),
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
)
if tc.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, parsedToken)
assert.Equal(t, tc.issuer, parsedToken.Issuer())
assert.Equal(t, userID, parsedToken.Subject())
}
})
}
}
func TestKeyIDMatching(t *testing.T) {
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err, "Generating key par 1 should succeed")
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err, "Generating key par 2 should succeed")
kid1 := "key-1"
kid2 := "key-2"
privateJwk1, err := jwk.FromRaw(priv1)
require.NoError(t, err, "Parsing JWKS private key 1 should work")
require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, kid1))
publicJwk1, err := jwk.FromRaw(pub1)
require.NoError(t, err, "Parsing JWKS public key 1 should work")
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1))
publicJwk2, err := jwk.FromRaw(pub2)
require.NoError(t, err, "Parsing JWKS public key 2 should work")
require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, kid2))
jwksSet := jwk.NewSet()
require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to JWKS set should succeed")
require.NoError(t, jwksSet.AddKey(publicJwk2), "Adding public key 2 to JWKS set should succeed")
// Create token signed with key-1
builder := jwt.NewBuilder()
builder.Issuer(validIssuer)
builder.IssuedAt(time.Now())
builder.Expiration(time.Now().Add(1 * time.Hour))
builder.Subject(userID)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk1))
require.NoError(t, err)
parsedToken, err := jwt.Parse(
signedToken,
jwt.WithValidate(true),
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
)
require.NoError(t, err)
assert.NotNil(t, parsedToken)
assert.Equal(t, validIssuer, parsedToken.Issuer())
headers, err := jws.Parse(signedToken)
require.NoError(t, err)
sigs := headers.Signatures()
require.Len(t, sigs, 1, "JWT should have exactly one signature")
kidValue := sigs[0].ProtectedHeaders().KeyID()
assert.Equal(t, kid1, kidValue, "JWT kid header should match signing key")
}
func TestUnknownKeyID(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err, "Generating public key 1 should succeed")
_, priv2, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err, "Generating private key 2 should succeed")
kid1 := "known-key"
kid2 := "unknown-key"
publicJwk1, err := jwk.FromRaw(pub1)
require.NoError(t, err, "Parsing JWKS public key 1 should work")
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1))
privateJwk2, err := jwk.FromRaw(priv2)
require.NoError(t, err, "Parsing JWKS private key 2 should work")
require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, kid2))
jwksSet := jwk.NewSet()
require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to set should succeed")
// Create token signed with unknown key-2
builder := jwt.NewBuilder()
builder.Issuer(validIssuer)
builder.IssuedAt(time.Now())
builder.Expiration(time.Now().Add(1 * time.Hour))
builder.Subject(userID)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk2))
require.NoError(t, err)
_, err = jwt.Parse(
signedToken,
jwt.WithValidate(true),
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
)
assert.Error(t, err, "Should fail when token's kid is not in JWKS")
}
-1
View File
@@ -58,7 +58,6 @@ packages:
Cache:
Hasher:
KeyRepository:
KeyManager:
Tokenizer:
PATS:
PATSRepository: