mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
NOISSUE - Improve JWKS (#3301)
Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
package keys_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -21,11 +19,8 @@ import (
|
||||
"github.com/absmach/supermq/auth/mocks"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -323,17 +318,17 @@ func TestRetrieveJWKS(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
svcRes []auth.JWK
|
||||
svcRes []auth.PublicKeyInfo
|
||||
status int
|
||||
}{
|
||||
{
|
||||
desc: "retrieve JWKS with keys",
|
||||
svcRes: []auth.JWK{newJWK(t), newJWK(t)},
|
||||
svcRes: []auth.PublicKeyInfo{newPublicKeyInfo(), newPublicKeyInfo()},
|
||||
status: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "retrieve empty JWKS",
|
||||
svcRes: []auth.JWK{},
|
||||
svcRes: []auth.PublicKeyInfo{},
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
@@ -352,14 +347,13 @@ func TestRetrieveJWKS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newJWK(t *testing.T) auth.JWK {
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
|
||||
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
|
||||
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
|
||||
return auth.NewJWK(jwkKey)
|
||||
func newPublicKeyInfo() auth.PublicKeyInfo {
|
||||
return auth.PublicKeyInfo{
|
||||
KeyID: "test-key-id",
|
||||
KeyType: "OKP",
|
||||
Algorithm: "EdDSA",
|
||||
Use: "sig",
|
||||
Curve: "Ed25519",
|
||||
X: "base64url-encoded-public-key",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -76,20 +75,17 @@ func (res revokeKeyRes) Empty() bool {
|
||||
}
|
||||
|
||||
type retrieveJWKSRes struct {
|
||||
Keys []auth.JWK `json:"-"`
|
||||
CacheMaxAge int `json:"-"`
|
||||
CacheStaleWhileRevalidate int `json:"-"`
|
||||
Keys []auth.PublicKeyInfo `json:"-"`
|
||||
CacheMaxAge int `json:"-"`
|
||||
CacheStaleWhileRevalidate int `json:"-"`
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) MarshalJSON() ([]byte, error) {
|
||||
set := jwk.NewSet()
|
||||
for _, k := range res.Keys {
|
||||
if err := set.AddKey(k.Key()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type jwksResponse struct {
|
||||
Keys []auth.PublicKeyInfo `json:"keys"`
|
||||
}
|
||||
|
||||
return json.Marshal(set)
|
||||
return json.Marshal(jwksResponse{Keys: res.Keys})
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) Code() int {
|
||||
|
||||
@@ -1,341 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jwt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
authjwt "github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/auth/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenType = "type"
|
||||
roleField = "role"
|
||||
VerifiedField = "verified"
|
||||
issuerName = "supermq.auth"
|
||||
)
|
||||
|
||||
var (
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
keyManager = new(mocks.KeyManager)
|
||||
)
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
validKey := key()
|
||||
signedToken, _, err := signToken(issuerName, validKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
key auth.Key
|
||||
managerReq jwt.Token
|
||||
managerResp []byte
|
||||
managerErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "issue new token",
|
||||
key: validKey,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with OAuth token",
|
||||
key: auth.Key{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a domain",
|
||||
key: auth.Key{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a subject",
|
||||
key: auth.Key{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: "",
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without type",
|
||||
key: auth.Key{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.KeyType(auth.InvitationKey + 1),
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a domain and subject",
|
||||
key: auth.Key{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: "",
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with failed to sign jwt",
|
||||
key: validKey,
|
||||
managerErr: svcerr.ErrAuthentication,
|
||||
err: authjwt.ErrSignJWT,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tc.managerReq = newToken(issuerName, tc.key)
|
||||
kmCall := keyManager.On("SignJWT", tc.managerReq).Return(tc.managerResp, tc.managerErr)
|
||||
tkn, err := tokenizer.Issue(tc.key)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
|
||||
}
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
validKey := key()
|
||||
signedTkn, parsedTkn, err := signToken(issuerName, validKey, true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
apiKey := key()
|
||||
apiKey.Type = auth.APIKey
|
||||
apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
|
||||
apiToken, _, err := signToken(issuerName, apiKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing api key expected to succeed: %s", err))
|
||||
|
||||
expKey := key()
|
||||
expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
|
||||
expToken, _, err := signToken(issuerName, expKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
|
||||
|
||||
emptySubjectKey := key()
|
||||
emptySubjectKey.Subject = ""
|
||||
signedEmptySubjectTkn, parsedEmptySubjectTkn, err := signToken(issuerName, emptySubjectKey, true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptyTypeKey := key()
|
||||
emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1)
|
||||
emptyTypeToken, _, err := signToken(issuerName, emptyTypeKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptyKey := key()
|
||||
emptyKey.Subject = ""
|
||||
|
||||
signedInValidTkn, parsedInvalidTkn, err := signToken("invalid.issuer", key(), true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
key auth.Key
|
||||
token string
|
||||
managerRes jwt.Token
|
||||
managerErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "parse valid key",
|
||||
key: validKey,
|
||||
token: signedTkn,
|
||||
managerRes: parsedTkn,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse invalid key",
|
||||
key: auth.Key{},
|
||||
token: "invalid",
|
||||
managerErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse expired key",
|
||||
key: auth.Key{},
|
||||
token: expToken,
|
||||
managerErr: errJWTExpiryKey,
|
||||
err: auth.ErrExpiry,
|
||||
},
|
||||
{
|
||||
desc: "parse expired API key",
|
||||
key: apiKey,
|
||||
token: apiToken,
|
||||
managerErr: errJWTExpiryKey,
|
||||
err: auth.ErrExpiry,
|
||||
},
|
||||
{
|
||||
desc: "parse token with invalid issuer",
|
||||
key: auth.Key{},
|
||||
token: signedInValidTkn,
|
||||
managerRes: parsedInvalidTkn,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty subject",
|
||||
key: emptySubjectKey,
|
||||
token: signedEmptySubjectTkn,
|
||||
managerRes: parsedEmptySubjectTkn,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty type",
|
||||
key: emptyTypeKey,
|
||||
token: emptyTypeToken,
|
||||
managerRes: newToken(issuerName, emptyKey),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
kmCall := keyManager.On("ParseJWT", tc.token).Return(tc.managerRes, tc.managerErr)
|
||||
key, err := tokenizer.Parse(context.Background(), tc.token)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
|
||||
}
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveJWKS(t *testing.T) {
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
keys []auth.JWK
|
||||
retrieveErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve jwks with keys",
|
||||
keys: []auth.JWK{newJWK(t), newJWK(t)},
|
||||
},
|
||||
{
|
||||
desc: "retrieve empty jwks",
|
||||
keys: []auth.JWK{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
kmCall := keyManager.On("PublicJWKS", mock.Anything).Return(tc.keys, tc.retrieveErr)
|
||||
jwks := tokenizer.RetrieveJWKS()
|
||||
assert.Equal(t, tc.keys, jwks, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.keys, jwks))
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func key() auth.Key {
|
||||
exp := time.Now().UTC().Add(10 * time.Minute).Round(time.Second)
|
||||
return auth.Key{
|
||||
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Role: auth.UserRole,
|
||||
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
|
||||
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: exp,
|
||||
}
|
||||
}
|
||||
|
||||
func newToken(issuerName string, key auth.Key) jwt.Token {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.
|
||||
Issuer(issuerName).
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, key.Type).
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(roleField, key.Role)
|
||||
builder.Claim(VerifiedField, key.Verified)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
if key.ID != "" {
|
||||
builder.JwtID(key.ID)
|
||||
}
|
||||
tkn, _ := builder.Build()
|
||||
return tkn
|
||||
}
|
||||
|
||||
func signToken(issuerName string, key auth.Key, parseToken bool) (string, jwt.Token, error) {
|
||||
tkn := newToken(issuerName, key)
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
pubKey := &pKey.PublicKey
|
||||
sTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.RS256, pKey))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if !parseToken {
|
||||
return string(sTkn), nil, nil
|
||||
}
|
||||
pTkn, err := jwt.Parse(
|
||||
sTkn,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(jwa.RS256, pubKey),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return string(sTkn), pTkn, nil
|
||||
}
|
||||
|
||||
func newJWK(t *testing.T) auth.JWK {
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
|
||||
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
|
||||
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
|
||||
return auth.NewJWK(jwkKey)
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
// errInvalidIssuer is returned when the issuer is not supermq.auth.
|
||||
errInvalidIssuer = errors.New("invalid token issuer value")
|
||||
// errInvalidType is returned when there is no type field.
|
||||
errInvalidType = errors.New("invalid token type")
|
||||
// errInvalidRole is returned when the role is invalid.
|
||||
errInvalidRole = errors.New("invalid role")
|
||||
// errInvalidVerified is returned when the verified is invalid.
|
||||
errInvalidVerified = errors.New("invalid verified")
|
||||
// errJWTExpiryKey is used to check if the token is expired.
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
// ErrSignJWT indicates an error in signing jwt token.
|
||||
ErrSignJWT = errors.New("failed to sign jwt token")
|
||||
// ErrValidateJWTToken indicates a failure to validate JWT token.
|
||||
ErrValidateJWTToken = errors.New("failed to validate jwt token")
|
||||
// ErrJSONHandle indicates an error in handling JSON.
|
||||
ErrJSONHandle = errors.New("failed to perform operation JSON")
|
||||
)
|
||||
|
||||
const (
|
||||
issuerName = "supermq.auth"
|
||||
tokenType = "type"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
patPrefix = "pat"
|
||||
)
|
||||
|
||||
type tokenizer struct {
|
||||
keyManager auth.KeyManager
|
||||
}
|
||||
|
||||
var _ auth.Tokenizer = (*tokenizer)(nil)
|
||||
|
||||
// New instantiates an implementation of Tokenizer service.
|
||||
func New(keyManager auth.KeyManager) auth.Tokenizer {
|
||||
return &tokenizer{
|
||||
keyManager: keyManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.
|
||||
Issuer(issuerName).
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, key.Type).
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(RoleField, key.Role)
|
||||
builder.Claim(VerifiedField, key.Verified)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
if key.ID != "" {
|
||||
builder.JwtID(key.ID)
|
||||
}
|
||||
tkn, err := builder.Build()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
signedTkn, err := tok.keyManager.SignJWT(tkn)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(ErrSignJWT, err)
|
||||
}
|
||||
return string(signedTkn), nil
|
||||
}
|
||||
|
||||
func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
|
||||
if len(token) >= 3 && token[:3] == patPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
|
||||
tkn, err := tok.validateToken(token)
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
key, err := ToKey(tkn)
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
|
||||
tkn, err := tok.keyManager.ParseJWT(token)
|
||||
if err != nil {
|
||||
if errors.Contains(err, errJWTExpiryKey) {
|
||||
return nil, auth.ErrExpiry
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
|
||||
if t.Issuer() != issuerName {
|
||||
return jwt.NewValidationError(errInvalidIssuer)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil {
|
||||
return nil, errors.Wrap(ErrValidateJWTToken, err)
|
||||
}
|
||||
|
||||
return tkn, nil
|
||||
}
|
||||
|
||||
func (tok *tokenizer) RetrieveJWKS() []auth.JWK {
|
||||
return tok.keyManager.PublicJWKS()
|
||||
}
|
||||
|
||||
func ToKey(tkn jwt.Token) (auth.Key, error) {
|
||||
data, err := json.Marshal(tkn.PrivateClaims())
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
|
||||
}
|
||||
var key auth.Key
|
||||
if err := json.Unmarshal(data, &key); err != nil {
|
||||
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
|
||||
}
|
||||
|
||||
tType, ok := tkn.Get(tokenType)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
kType, ok := tType.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
kt := auth.KeyType(kType)
|
||||
if !kt.Validate() {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
|
||||
tRole, ok := tkn.Get(RoleField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
kRole, ok := tRole.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
tVerified, ok := tkn.Get(VerifiedField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
kVerified, ok := tVerified.(bool)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
|
||||
kr := auth.Role(kRole)
|
||||
if !kr.Validate() {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
key.ID = tkn.JwtID()
|
||||
key.Type = auth.KeyType(kType)
|
||||
key.Role = auth.Role(kRole)
|
||||
key.Issuer = tkn.Issuer()
|
||||
key.Subject = tkn.Subject()
|
||||
key.IssuedAt = tkn.IssuedAt()
|
||||
key.ExpiresAt = tkn.Expiration()
|
||||
key.Verified = kVerified
|
||||
|
||||
return key, nil
|
||||
}
|
||||
+35
-25
@@ -4,47 +4,57 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
|
||||
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
|
||||
ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm")
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key.
|
||||
type JWK struct {
|
||||
key jwk.Key
|
||||
// PublicKeyInfo represents a public key for external distribution via JWKS.
|
||||
// This follows RFC 7517 (JSON Web Key) specification.
|
||||
type PublicKeyInfo struct {
|
||||
KeyID string `json:"kid"`
|
||||
KeyType string `json:"kty"`
|
||||
Algorithm string `json:"alg"`
|
||||
Use string `json:"use,omitempty"`
|
||||
|
||||
// EdDSA (Ed25519) fields
|
||||
Curve string `json:"crv,omitempty"`
|
||||
X string `json:"x,omitempty"`
|
||||
|
||||
// Future: RSA fields (n, e), ECDSA fields (x, y, crv), etc.
|
||||
}
|
||||
|
||||
// NewJWK creates a new JWK from a jwk.Key.
|
||||
func NewJWK(key jwk.Key) JWK {
|
||||
return JWK{key: key}
|
||||
}
|
||||
|
||||
// Key returns the underlying jwk.Key.
|
||||
func (j JWK) Key() jwk.Key {
|
||||
return j.key
|
||||
}
|
||||
|
||||
// KeyManager represents a manager for JWT keys.
|
||||
type KeyManager interface {
|
||||
SignJWT(token jwt.Token) ([]byte, error)
|
||||
|
||||
ParseJWT(token string) (jwt.Token, error)
|
||||
|
||||
PublicJWKS() []JWK
|
||||
// Tokenizer handles token creation and verification for authentication.
|
||||
// Implementations manage underlying cryptographic operations and key distribution.
|
||||
type Tokenizer interface {
|
||||
// Issue creates a signed token string from the given key claims.
|
||||
Issue(key Key) (token string, err error)
|
||||
|
||||
// Parse verifies and parses a token string (JWT or PAT), returning the extracted claims.
|
||||
// For PAT tokens (prefix "pat"), returns a Key with Type set to PersonalAccessToken.
|
||||
// For JWT tokens, performs cryptographic verification and returns the parsed claims.
|
||||
Parse(ctx context.Context, token string) (key Key, err error)
|
||||
|
||||
// RetrieveJWKS returns public keys for distribution via JWKS endpoint.
|
||||
// Returns ErrPublicKeysNotSupported for symmetric tokenizers (HMAC).
|
||||
RetrieveJWKS() ([]PublicKeyInfo, error)
|
||||
}
|
||||
|
||||
// IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based).
|
||||
// Returns true for HMAC algorithms (HS256, HS384, HS512).
|
||||
// Returns false for asymmetric algorithms (EdDSA).
|
||||
// Returns error for unsupported algorithms.
|
||||
func IsSymmetricAlgorithm(alg string) (bool, error) {
|
||||
switch alg {
|
||||
case "HS256", "HS384", "HS512":
|
||||
return true, nil
|
||||
case "EdDSA":
|
||||
return false, nil
|
||||
case "HS256", "HS384", "HS512":
|
||||
return true, nil
|
||||
default:
|
||||
return false, ErrUnsupportedKeyAlgorithm
|
||||
}
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package asymmetric
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
errLoadingPrivateKey = errors.New("failed to load private key")
|
||||
errInvalidKeySize = errors.New("invalid ED25519 key size")
|
||||
errParsingPrivateKey = errors.New("failed to parse private key")
|
||||
errInvalidKeyType = errors.New("private key is not ED25519")
|
||||
errGeneratingKID = errors.New("failed to generate key ID")
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
privateKey jwk.Key
|
||||
publicKey jwk.Key
|
||||
kid string
|
||||
}
|
||||
|
||||
var _ auth.KeyManager = (*manager)(nil)
|
||||
|
||||
// NewKeyManager creates a new asymmetric key manager that loads the private key from a file.
|
||||
func NewKeyManager(privateKeyPath string, idProvider supermq.IDProvider) (auth.KeyManager, error) {
|
||||
kid, err := idProvider.ID()
|
||||
if err != nil {
|
||||
return nil, errors.Join(errGeneratingKID, err)
|
||||
}
|
||||
|
||||
privateJwk, publicJwk, err := loadKeyPair(privateKeyPath, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manager{
|
||||
privateKey: privateJwk,
|
||||
publicKey: publicJwk,
|
||||
kid: kid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, jwt.WithKey(jwa.EdDSA, km.privateKey))
|
||||
}
|
||||
|
||||
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
|
||||
set := jwk.NewSet()
|
||||
if err := set.AddKey(km.publicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tkn, nil
|
||||
}
|
||||
|
||||
func (km *manager) PublicJWKS() []auth.JWK {
|
||||
return []auth.JWK{auth.NewJWK(km.publicKey)}
|
||||
}
|
||||
|
||||
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(errLoadingPrivateKey, err)
|
||||
}
|
||||
|
||||
var privateKey ed25519.PrivateKey
|
||||
block, _ := pem.Decode(privateKeyBytes)
|
||||
switch {
|
||||
case block != nil:
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(errParsingPrivateKey, err)
|
||||
}
|
||||
var ok bool
|
||||
privateKey, ok = parsedKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, errInvalidKeyType
|
||||
}
|
||||
default:
|
||||
if len(privateKeyBytes) != ed25519.PrivateKeySize {
|
||||
return nil, nil, errInvalidKeySize
|
||||
}
|
||||
privateKey = ed25519.PrivateKey(privateKeyBytes)
|
||||
}
|
||||
|
||||
publicKey := privateKey.Public().(ed25519.PublicKey)
|
||||
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
publicJwk, err := jwk.FromRaw(publicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return privateJwk, publicJwk, nil
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package symmetric
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
algorithm jwa.KeyAlgorithm
|
||||
secret []byte
|
||||
}
|
||||
|
||||
var _ auth.KeyManager = (*manager)(nil)
|
||||
|
||||
func NewKeyManager(algorithm string, secret []byte) (auth.KeyManager, error) {
|
||||
alg := jwa.KeyAlgorithmFrom(algorithm)
|
||||
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
|
||||
return nil, auth.ErrUnsupportedKeyAlgorithm
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
return nil, auth.ErrInvalidSymmetricKey
|
||||
}
|
||||
return &manager{
|
||||
secret: secret,
|
||||
algorithm: alg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, jwt.WithKey(km.algorithm, km.secret))
|
||||
}
|
||||
|
||||
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
|
||||
return jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(km.algorithm, km.secret),
|
||||
)
|
||||
}
|
||||
|
||||
func (km *manager) PublicJWKS() []auth.JWK {
|
||||
return nil
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id aut
|
||||
return lm.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.JWK) {
|
||||
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
|
||||
@@ -67,7 +67,7 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.K
|
||||
return ms.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RetrieveJWKS() []auth.JWK {
|
||||
func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "retrieve_jwks").Add(1)
|
||||
ms.latency.With("method", "retrieve_jwks").Observe(time.Since(begin).Seconds())
|
||||
|
||||
@@ -61,7 +61,7 @@ func (tm *tracingMiddleware) Identify(ctx context.Context, token string) (auth.K
|
||||
return tm.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RetrieveJWKS() []auth.JWK {
|
||||
func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
|
||||
return tm.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewKeyManager creates a new instance of KeyManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewKeyManager(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *KeyManager {
|
||||
mock := &KeyManager{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// KeyManager is an autogenerated mock type for the KeyManager type
|
||||
type KeyManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type KeyManager_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *KeyManager) EXPECT() *KeyManager_Expecter {
|
||||
return &KeyManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// ParseJWT provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) ParseJWT(token string) (jwt.Token, error) {
|
||||
ret := _mock.Called(token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ParseJWT")
|
||||
}
|
||||
|
||||
var r0 jwt.Token
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(string) (jwt.Token, error)); ok {
|
||||
return returnFunc(token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(string) jwt.Token); ok {
|
||||
r0 = returnFunc(token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(jwt.Token)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = returnFunc(token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// KeyManager_ParseJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJWT'
|
||||
type KeyManager_ParseJWT_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ParseJWT is a helper method to define mock.On call
|
||||
// - token string
|
||||
func (_e *KeyManager_Expecter) ParseJWT(token interface{}) *KeyManager_ParseJWT_Call {
|
||||
return &KeyManager_ParseJWT_Call{Call: _e.mock.On("ParseJWT", token)}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) Run(run func(token string)) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 string
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) Return(token1 jwt.Token, err error) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Return(token1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) RunAndReturn(run func(token string) (jwt.Token, error)) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// PublicJWKS provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) PublicJWKS() []auth.JWK {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for PublicJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// KeyManager_PublicJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublicJWKS'
|
||||
type KeyManager_PublicJWKS_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// PublicJWKS is a helper method to define mock.On call
|
||||
func (_e *KeyManager_Expecter) PublicJWKS() *KeyManager_PublicJWKS_Call {
|
||||
return &KeyManager_PublicJWKS_Call{Call: _e.mock.On("PublicJWKS")}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) Run(run func()) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) Return(jWKs []auth.JWK) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) RunAndReturn(run func() []auth.JWK) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SignJWT provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
ret := _mock.Called(token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SignJWT")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(jwt.Token) ([]byte, error)); ok {
|
||||
return returnFunc(token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(jwt.Token) []byte); ok {
|
||||
r0 = returnFunc(token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(jwt.Token) error); ok {
|
||||
r1 = returnFunc(token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// KeyManager_SignJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignJWT'
|
||||
type KeyManager_SignJWT_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SignJWT is a helper method to define mock.On call
|
||||
// - token jwt.Token
|
||||
func (_e *KeyManager_Expecter) SignJWT(token interface{}) *KeyManager_SignJWT_Call {
|
||||
return &KeyManager_SignJWT_Call{Call: _e.mock.On("SignJWT", token)}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) Run(run func(token jwt.Token)) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 jwt.Token
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(jwt.Token)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) Return(bytes []byte, err error) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Return(bytes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) RunAndReturn(run func(token jwt.Token) ([]byte, error)) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -1029,19 +1029,19 @@ func (_c *Service_ResetPATSecret_Call) RunAndReturn(run func(ctx context.Context
|
||||
}
|
||||
|
||||
// RetrieveJWKS provides a mock function for the type Service
|
||||
func (_mock *Service) RetrieveJWKS() []auth.JWK {
|
||||
func (_mock *Service) RetrieveJWKS() []auth.PublicKeyInfo {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
var r0 []auth.PublicKeyInfo
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
r0 = ret.Get(0).([]auth.PublicKeyInfo)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
@@ -1064,12 +1064,12 @@ func (_c *Service_RetrieveJWKS_Call) Run(run func()) *Service_RetrieveJWKS_Call
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
func (_c *Service_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Return(publicKeyInfos)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Service_RetrieveJWKS_Call {
|
||||
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.PublicKeyInfo) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
+17
-8
@@ -169,22 +169,31 @@ func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token
|
||||
}
|
||||
|
||||
// RetrieveJWKS provides a mock function for the type Tokenizer
|
||||
func (_mock *Tokenizer) RetrieveJWKS() []auth.JWK {
|
||||
func (_mock *Tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
var r0 []auth.PublicKeyInfo
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() ([]auth.PublicKeyInfo, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.PublicKeyInfo); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
r0 = ret.Get(0).([]auth.PublicKeyInfo)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Tokenizer_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
|
||||
@@ -204,12 +213,12 @@ func (_c *Tokenizer_RetrieveJWKS_Call) Run(run func()) *Tokenizer_RetrieveJWKS_C
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) Return(publicKeyInfos []auth.PublicKeyInfo, err error) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Return(publicKeyInfos, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() ([]auth.PublicKeyInfo, error)) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -186,6 +186,8 @@ func ParseEntityType(et string) (EntityType, error) {
|
||||
return UsersType, nil
|
||||
case DashboardsStr:
|
||||
return DashboardType, nil
|
||||
case MessagesStr:
|
||||
return MessagesType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown domain entity type %s", et)
|
||||
}
|
||||
|
||||
+12
-4
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
"github.com/google/uuid"
|
||||
@@ -84,8 +85,8 @@ type Authn interface {
|
||||
// other reason, non-nil error value is returned in response.
|
||||
Identify(ctx context.Context, token string) (Key, error)
|
||||
|
||||
// RetrieveJWKS retrieves a JWKs to validate issued tokens.
|
||||
RetrieveJWKS() []JWK
|
||||
// RetrieveJWKS retrieves public keys to validate issued tokens.
|
||||
RetrieveJWKS() []PublicKeyInfo
|
||||
}
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
@@ -201,8 +202,12 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svc service) RetrieveJWKS() []JWK {
|
||||
return svc.tokenizer.RetrieveJWKS()
|
||||
func (svc service) RetrieveJWKS() []PublicKeyInfo {
|
||||
keys, err := svc.tokenizer.RetrieveJWKS()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
|
||||
@@ -803,6 +808,9 @@ func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (
|
||||
|
||||
_, err = svc.pats.Retrieve(ctx, key.Subject, patID)
|
||||
if err != nil {
|
||||
if errors.Contains(err, repoerr.ErrNotFound) {
|
||||
return Key{}, svcerr.ErrNotFound
|
||||
}
|
||||
return Key{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Tokenizer specifies API for encoding and decoding between string and Key.
|
||||
type Tokenizer interface {
|
||||
// Issue converts API Key to its string representation.
|
||||
Issue(key Key) (token string, err error)
|
||||
|
||||
// Parse extracts API Key data from string token.
|
||||
Parse(ctx context.Context, token string) (key Key, err error)
|
||||
|
||||
// RetrieveJWKS returns the JSON Web Key Set.
|
||||
RetrieveJWKS() []JWK
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
# Asymmetric Tokenizer
|
||||
|
||||
EdDSA (Ed25519) tokenizer with support for zero-downtime key rotation.
|
||||
|
||||
## Features
|
||||
|
||||
- **Single-key mode** - Simple setup with one active key
|
||||
- **Two-key mode** - Active + retiring keys for _zero-downtime rotation_
|
||||
- **JWKS endpoint** - Publishes all valid public keys for token verification
|
||||
|
||||
## Configuration
|
||||
|
||||
The tokenizer uses environment variables to specify key file paths:
|
||||
|
||||
| Environment Variable | Required | Description |
|
||||
| --------------------------------- | -------- | ------------------------------------------------ |
|
||||
| `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` | Yes | Path to active private key file |
|
||||
| `SMQ_AUTH_KEYS_RETIRING_KEY_PATH` | No | Path to retiring private key file (for rotation) |
|
||||
|
||||
Please note that key names are used as **key IDs (kid)**.
|
||||
|
||||
### Single-Key Mode
|
||||
|
||||
Set only the active key path:
|
||||
|
||||
```bash
|
||||
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key"
|
||||
```
|
||||
|
||||
The tokenizer will:
|
||||
- Issue new tokens signed with the active key
|
||||
- Verify tokens using the active key
|
||||
- Return one public key in JWKS endpoint
|
||||
|
||||
### Two-Key Mode (Key Rotation)
|
||||
|
||||
Set both active and retiring key paths:
|
||||
|
||||
```bash
|
||||
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key"
|
||||
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
|
||||
```
|
||||
|
||||
The tokenizer will:
|
||||
- Issue new tokens signed with the active key
|
||||
- Verify tokens using both active and retiring keys
|
||||
- Return both public keys in JWKS endpoint
|
||||
|
||||
## Key Rotation Process
|
||||
|
||||
Zero-downtime key rotation in 3 simple steps:
|
||||
|
||||
### 1. Generate New Key
|
||||
|
||||
```bash
|
||||
openssl genpkey -algorithm Ed25519 -out keys/new.key
|
||||
```
|
||||
|
||||
### 2. Update Environment & Restart
|
||||
|
||||
Move the current active key to retiring position and set the new key as active:
|
||||
|
||||
```bash
|
||||
# Before rotation
|
||||
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/current.key"
|
||||
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # No retiring key
|
||||
|
||||
# During rotation (both keys active for grace period)
|
||||
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key"
|
||||
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/current.key"
|
||||
|
||||
# After rotation (restart service with new config)
|
||||
docker-compose restart auth
|
||||
```
|
||||
|
||||
During the grace period, tokens signed with either key remain valid.
|
||||
|
||||
### 3. Clean Up After Grace Period
|
||||
|
||||
After the grace period expires (typically 7-30 days), remove the retiring key:
|
||||
|
||||
```bash
|
||||
# Remove retiring key configuration
|
||||
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key"
|
||||
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # Remove retiring key
|
||||
|
||||
# Restart service
|
||||
docker-compose restart auth
|
||||
|
||||
# Delete old key file
|
||||
rm keys/current.key
|
||||
```
|
||||
|
||||
## Grace Period Recommendations
|
||||
|
||||
**Recommended:** 168 hours (7 days)
|
||||
**Minimum:** 24 hours
|
||||
**Maximum:** 720 hours (30 days)
|
||||
|
||||
The grace period should be longer than your longest-lived access token duration.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
- Store private keys with `0600` permissions
|
||||
- Use cryptographically secure key generation:
|
||||
```bash
|
||||
openssl genpkey -algorithm Ed25519 -out private.key
|
||||
chmod 600 private.key
|
||||
```
|
||||
- Rotate keys regularly:
|
||||
- Standard environments: every 90 days
|
||||
- High-security environments: every 30 days
|
||||
- Never commit keys to version control
|
||||
- Use secrets management in production (HashiCorp Vault, AWS Secrets Manager, etc.)
|
||||
|
||||
## Example: Complete Rotation
|
||||
|
||||
```bash
|
||||
# Day 0: Normal operation
|
||||
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2024.pem"
|
||||
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH=""
|
||||
|
||||
# Day 1: Start rotation - generate new key
|
||||
openssl genpkey -algorithm Ed25519 -out ./keys/key-2025.pem
|
||||
chmod 600 ./keys/key-2025.pem
|
||||
|
||||
# Day 1: Update config and restart
|
||||
export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2025.pem"
|
||||
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/key-2024.pem"
|
||||
docker-compose restart auth
|
||||
|
||||
# Day 8: Grace period expired - remove old key
|
||||
export SMQ_AUTH_KEYS_RETIRING_KEY_PATH=""
|
||||
docker-compose restart auth
|
||||
rm ./keys/key-2024.pem
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Active key not found
|
||||
|
||||
```
|
||||
Error: active key file not found: ./keys/active.key
|
||||
```
|
||||
|
||||
**Solution:** Ensure the file exists and path is correct. Verify `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` environment variable.
|
||||
|
||||
### Retiring key warning
|
||||
|
||||
If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key:
|
||||
|
||||
```
|
||||
WARN: failed to load retiring key, continuing without it
|
||||
```
|
||||
|
||||
This is by design - a missing retiring key won't prevent startup.
|
||||
|
||||
### Invalid key format
|
||||
|
||||
```
|
||||
Error: failed to parse private key
|
||||
```
|
||||
|
||||
**Solution:** Ensure you're using Ed25519 keys in PEM format (PKCS8).
|
||||
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package asymmetric_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type incrementingIDProvider struct {
|
||||
counter int
|
||||
}
|
||||
|
||||
func (p *incrementingIDProvider) ID() (string, error) {
|
||||
p.counter++
|
||||
return fmt.Sprintf("key-id-%d", p.counter), nil
|
||||
}
|
||||
|
||||
func TestTwoKeyRotation(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, activePriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, retiringPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeKeyPath := filepath.Join(tmpDir, "active.key")
|
||||
retiringKeyPath := filepath.Join(tmpDir, "retiring.key")
|
||||
|
||||
saveKey(t, activePriv, activeKeyPath)
|
||||
saveKey(t, retiringPriv, retiringKeyPath)
|
||||
|
||||
idProvider := &incrementingIDProvider{}
|
||||
tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
testKey := auth.Key{
|
||||
ID: "test-key",
|
||||
Type: auth.AccessKey,
|
||||
Subject: "user-123",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
token, err := tokenizer.Issue(testKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
verified, err := tokenizer.Parse(context.Background(), token)
|
||||
require.NoError(t, err, "Should work with active token")
|
||||
assert.Equal(t, testKey.Subject, verified.Subject)
|
||||
|
||||
publicKeys, err := tokenizer.RetrieveJWKS()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, publicKeys, 2, "Should return both active and retiring keys")
|
||||
|
||||
keyIDs := make(map[string]bool)
|
||||
for _, pk := range publicKeys {
|
||||
keyIDs[pk.KeyID] = true
|
||||
}
|
||||
assert.Len(t, keyIDs, 2, "Both keys should have unique IDs")
|
||||
}
|
||||
|
||||
func TestSingleKeyMode(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyPath := filepath.Join(tmpDir, "single.key")
|
||||
saveKey(t, privateKey, keyPath)
|
||||
|
||||
idProvider := &mockIDProvider{id: "single-id"}
|
||||
tokenizer, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
testKey := auth.Key{
|
||||
ID: "test",
|
||||
Type: auth.AccessKey,
|
||||
Subject: "user",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
}
|
||||
|
||||
token, err := tokenizer.Issue(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tokenizer.Parse(context.Background(), token)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeys, err := tokenizer.RetrieveJWKS()
|
||||
require.NoError(t, err, "Should return one active key")
|
||||
assert.Len(t, publicKeys, 1, "Should return only the active key")
|
||||
}
|
||||
|
||||
func TestMissingRetiringKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeKeyPath := filepath.Join(tmpDir, "active.key")
|
||||
saveKey(t, privateKey, activeKeyPath)
|
||||
|
||||
retiringKeyPath := filepath.Join(tmpDir, "nonexistent.key")
|
||||
|
||||
idProvider := &mockIDProvider{id: "test-id"}
|
||||
tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger())
|
||||
require.NoError(t, err, "Should succeed even if retiring key is missing")
|
||||
|
||||
testKey := auth.Key{
|
||||
ID: "test",
|
||||
Type: auth.AccessKey,
|
||||
Subject: "user",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
}
|
||||
|
||||
token, err := tokenizer.Issue(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tokenizer.Parse(context.Background(), token)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeys, err := tokenizer.RetrieveJWKS()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, publicKeys, 1, "Should return only active key when retiring key is missing")
|
||||
}
|
||||
|
||||
func saveKey(t *testing.T, privateKey ed25519.PrivateKey, path string) {
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
err = os.WriteFile(path, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,246 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package asymmetric
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
const patPrefix = "pat"
|
||||
|
||||
var (
|
||||
errLoadingPrivateKey = errors.New("failed to load private key")
|
||||
errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID")
|
||||
errInvalidKeySize = errors.New("invalid ED25519 key size")
|
||||
errParsingPrivateKey = errors.New("failed to parse private key")
|
||||
errInvalidKeyType = errors.New("private key is not ED25519")
|
||||
errNoValidPublicKeys = errors.New("no valid public keys available")
|
||||
errNoActiveKey = errors.New("active key not loaded")
|
||||
)
|
||||
|
||||
type keyPair struct {
|
||||
id string
|
||||
privateKey jwk.Key
|
||||
publicKey jwk.Key
|
||||
}
|
||||
|
||||
// Tokenizer is safe for concurrent use. Keys are set during construction
|
||||
// and never modified afterward.
|
||||
type tokenizer struct {
|
||||
activeKey *keyPair
|
||||
retiringKey *keyPair // Optional, for key rotation grace period
|
||||
}
|
||||
|
||||
var _ auth.Tokenizer = (*tokenizer)(nil)
|
||||
|
||||
// NewTokenizer creates a new asymmetric tokenizer with active and optionally retiring keys.
|
||||
// activeKeyPath is required. retiringKeyPath is optional (can be empty string).
|
||||
// If retiringKeyPath is provided but the file doesn't exist or is invalid, a warning is logged
|
||||
// but the tokenizer is still created with just the active key.
|
||||
// Key IDs are derived from filenames to ensure consistency across multiple service instances.
|
||||
func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDProvider, logger *slog.Logger) (auth.Tokenizer, error) {
|
||||
activeKID := keyIDFromPath(activeKeyPath)
|
||||
|
||||
activePrivateJwk, activePublicJwk, err := loadKeyPair(activeKeyPath, activeKID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr := &tokenizer{
|
||||
activeKey: &keyPair{
|
||||
id: activeKID,
|
||||
privateKey: activePrivateJwk,
|
||||
publicKey: activePublicJwk,
|
||||
},
|
||||
}
|
||||
|
||||
if retiringKeyPath != "" {
|
||||
retiringKID := keyIDFromPath(retiringKeyPath)
|
||||
if retiringKID == activeKID {
|
||||
return nil, errDuplicateRetiringKeyID
|
||||
}
|
||||
|
||||
retiringPrivateJwk, retiringPublicJwk, err := loadKeyPair(retiringKeyPath, retiringKID)
|
||||
if err != nil {
|
||||
logger.Warn("failed to load retiring key, continuing without it", slog.Any("error", err))
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
mgr.retiringKey = &keyPair{
|
||||
id: retiringKID,
|
||||
privateKey: retiringPrivateJwk,
|
||||
publicKey: retiringPublicJwk,
|
||||
}
|
||||
logger.Info("loaded retiring key for rotation grace period", slog.String("key_id", retiringKID))
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
if km.activeKey == nil {
|
||||
return "", errNoActiveKey
|
||||
}
|
||||
|
||||
tkn, err := smqjwt.BuildToken(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers := jws.NewHeaders()
|
||||
if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(signedBytes), nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
|
||||
set := jwk.NewSet()
|
||||
if err := set.AddKey(km.activeKey.publicKey); err != nil {
|
||||
return auth.Key{}, err
|
||||
}
|
||||
if km.retiringKey != nil {
|
||||
if err := set.AddKey(km.retiringKey.publicKey); err != nil {
|
||||
return auth.Key{}, err
|
||||
}
|
||||
}
|
||||
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(tokenString),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
if err != nil {
|
||||
return auth.Key{}, err
|
||||
}
|
||||
|
||||
if tkn.Issuer() != smqjwt.IssuerName {
|
||||
return auth.Key{}, smqjwt.ErrInvalidIssuer
|
||||
}
|
||||
|
||||
return smqjwt.ToKey(tkn)
|
||||
}
|
||||
|
||||
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
publicKeys := make([]auth.PublicKeyInfo, 0, 2)
|
||||
|
||||
if km.activeKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil {
|
||||
publicKeys = append(publicKeys, *pkInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if km.retiringKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil {
|
||||
publicKeys = append(publicKeys, *pkInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if len(publicKeys) == 0 {
|
||||
return nil, errNoValidPublicKeys
|
||||
}
|
||||
|
||||
return publicKeys, nil
|
||||
}
|
||||
|
||||
func extractPublicKeyInfo(kp *keyPair) *auth.PublicKeyInfo {
|
||||
var rawKey ed25519.PublicKey
|
||||
if err := kp.publicKey.Raw(&rawKey); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &auth.PublicKeyInfo{
|
||||
KeyID: kp.id,
|
||||
KeyType: "OKP",
|
||||
Algorithm: "EdDSA",
|
||||
Use: "sig",
|
||||
Curve: "Ed25519",
|
||||
X: base64.RawURLEncoding.EncodeToString(rawKey),
|
||||
}
|
||||
}
|
||||
|
||||
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(errLoadingPrivateKey, err)
|
||||
}
|
||||
|
||||
var privateKey ed25519.PrivateKey
|
||||
block, _ := pem.Decode(privateKeyBytes)
|
||||
switch {
|
||||
case block != nil:
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(errParsingPrivateKey, err)
|
||||
}
|
||||
var ok bool
|
||||
privateKey, ok = parsedKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, errInvalidKeyType
|
||||
}
|
||||
default:
|
||||
if len(privateKeyBytes) != ed25519.PrivateKeySize {
|
||||
return nil, nil, errInvalidKeySize
|
||||
}
|
||||
privateKey = ed25519.PrivateKey(privateKeyBytes)
|
||||
}
|
||||
|
||||
publicKey := privateKey.Public().(ed25519.PublicKey)
|
||||
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
publicJwk, err := jwk.FromRaw(publicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return privateJwk, publicJwk, nil
|
||||
}
|
||||
|
||||
func keyIDFromPath(path string) string {
|
||||
base := filepath.Base(path)
|
||||
ext := filepath.Ext(base)
|
||||
return strings.TrimSuffix(base, ext)
|
||||
}
|
||||
@@ -0,0 +1,418 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package asymmetric_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockIDProvider struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (m *mockIDProvider) ID() (string, error) {
|
||||
return m.id, nil
|
||||
}
|
||||
|
||||
func newTestLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
func TestNewKeyManager(t *testing.T) {
|
||||
idProvider := &mockIDProvider{id: "unused"}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyPath := filepath.Join(tmpDir, "private.key")
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
setupKey func() string
|
||||
expectErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid PEM key",
|
||||
setupKey: func() string {
|
||||
err := os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
return keyPath
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid raw key",
|
||||
setupKey: func() string {
|
||||
rawKeyPath := filepath.Join(tmpDir, "raw_private.key")
|
||||
err := os.WriteFile(rawKeyPath, privateKey, 0o600)
|
||||
require.NoError(t, err)
|
||||
return rawKeyPath
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent key file",
|
||||
setupKey: func() string {
|
||||
return filepath.Join(tmpDir, "nonexistent.key")
|
||||
},
|
||||
expectErr: true,
|
||||
errContains: "failed to load private key",
|
||||
},
|
||||
{
|
||||
name: "invalid key size",
|
||||
setupKey: func() string {
|
||||
invalidPath := filepath.Join(tmpDir, "invalid.key")
|
||||
err := os.WriteFile(invalidPath, []byte("invalid"), 0o600)
|
||||
require.NoError(t, err)
|
||||
return invalidPath
|
||||
},
|
||||
expectErr: true,
|
||||
errContains: "invalid ED25519 key size",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
path := tc.setupKey()
|
||||
|
||||
km, err := asymmetric.NewTokenizer(path, "", idProvider, newTestLogger())
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tc.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
}
|
||||
assert.Nil(t, km)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, km)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSign(t *testing.T) {
|
||||
idProvider := &mockIDProvider{id: "unused"}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyPath := filepath.Join(tmpDir, "private.key")
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
key auth.Key
|
||||
}{
|
||||
{
|
||||
name: "sign valid key with all fields",
|
||||
key: auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sign key without subject",
|
||||
key: auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.APIKey,
|
||||
Issuer: "supermq.auth",
|
||||
Role: auth.AdminRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
|
||||
Verified: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sign key without ID",
|
||||
key: auth.Key{
|
||||
Type: auth.AccessKey,
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token, err := km.Issue(tc.key)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
parts := splitJWT(token)
|
||||
assert.Equal(t, 3, len(parts), "JWT should have 3 parts")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
idProvider := &mockIDProvider{id: "unused"}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyPath := filepath.Join(tmpDir, "private.key")
|
||||
kid := "private"
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
validKey := auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
validToken, err := km.Issue(validKey)
|
||||
require.NoError(t, err, "Signing a valid token should succeed")
|
||||
|
||||
expiredKey := validKey
|
||||
expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC()
|
||||
expiredToken, err := km.Issue(expiredKey)
|
||||
require.NoError(t, err, "Creating an expired token should succeed")
|
||||
|
||||
wrongIssuerKey := validKey
|
||||
wrongIssuerKey.Issuer = "wrong.issuer"
|
||||
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid))
|
||||
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(wrongIssuerKey.Issuer).
|
||||
Subject(wrongIssuerKey.Subject).
|
||||
IssuedAt(wrongIssuerKey.IssuedAt).
|
||||
Expiration(wrongIssuerKey.ExpiresAt).
|
||||
JwtID(wrongIssuerKey.ID).
|
||||
Claim("type", wrongIssuerKey.Type).
|
||||
Claim("role", wrongIssuerKey.Role).
|
||||
Claim("verified", wrongIssuerKey.Verified)
|
||||
|
||||
wrongIssuerJWT, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.EdDSA, privateJwk))
|
||||
require.NoError(t, err)
|
||||
wrongIssuerToken := string(wrongIssuerTokenBytes)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
expectErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "verify valid token",
|
||||
token: validToken,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verify expired token",
|
||||
token: expiredToken,
|
||||
expectErr: true,
|
||||
errContains: "exp",
|
||||
},
|
||||
{
|
||||
name: "verify token with wrong issuer",
|
||||
token: wrongIssuerToken,
|
||||
expectErr: true,
|
||||
errContains: "invalid token issuer",
|
||||
},
|
||||
{
|
||||
name: "verify malformed token",
|
||||
token: "not.a.valid.jwt",
|
||||
expectErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
key, err := km.Parse(context.Background(), tc.token)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tc.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validKey.Subject, key.Subject)
|
||||
assert.Equal(t, validKey.Type, key.Type)
|
||||
assert.Equal(t, validKey.Role, key.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeys(t *testing.T) {
|
||||
idProvider := &mockIDProvider{id: "unused"}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyPath := filepath.Join(tmpDir, "private.key")
|
||||
kid := "private"
|
||||
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := km.RetrieveJWKS()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
key := keys[0]
|
||||
assert.Equal(t, kid, key.KeyID)
|
||||
assert.Equal(t, "OKP", key.KeyType)
|
||||
assert.Equal(t, "EdDSA", key.Algorithm)
|
||||
assert.Equal(t, "sig", key.Use)
|
||||
assert.Equal(t, "Ed25519", key.Curve)
|
||||
assert.NotEmpty(t, key.X)
|
||||
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(key.X)
|
||||
assert.NoError(t, err, "The public key should be decoded")
|
||||
assert.Equal(t, publicKey, ed25519.PublicKey(decoded))
|
||||
}
|
||||
|
||||
func TestSignAndVerifyRoundTrip(t *testing.T) {
|
||||
idProvider := &mockIDProvider{id: "unused"}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyPath := filepath.Join(tmpDir, "private.key")
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Key,
|
||||
}
|
||||
|
||||
err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger())
|
||||
require.NoError(t, err)
|
||||
|
||||
originalKey := auth.Key{
|
||||
ID: "key-123",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-456",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC().Truncate(time.Second),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
token, err := km.Issue(originalKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifiedKey, err := km.Parse(context.Background(), token)
|
||||
require.NoError(t, err, "Verification of a valid key should succeed")
|
||||
|
||||
assert.Equal(t, originalKey.ID, verifiedKey.ID)
|
||||
assert.Equal(t, originalKey.Type, verifiedKey.Type)
|
||||
assert.Equal(t, originalKey.Subject, verifiedKey.Subject)
|
||||
assert.Equal(t, originalKey.Role, verifiedKey.Role)
|
||||
assert.Equal(t, originalKey.Verified, verifiedKey.Verified)
|
||||
assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second)
|
||||
assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
func splitJWT(token string) []string {
|
||||
parts := []string{}
|
||||
start := 0
|
||||
for i := 0; i < len(token); i++ {
|
||||
if token[i] == '.' {
|
||||
parts = append(parts, token[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
parts = append(parts, token[start:])
|
||||
return parts
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package symmetric
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
patPrefix = "pat"
|
||||
)
|
||||
|
||||
var errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
|
||||
type tokenizer struct {
|
||||
algorithm jwa.KeyAlgorithm
|
||||
secret []byte
|
||||
}
|
||||
|
||||
var _ auth.Tokenizer = (*tokenizer)(nil)
|
||||
|
||||
func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) {
|
||||
alg := jwa.KeyAlgorithmFrom(algorithm)
|
||||
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
|
||||
return nil, auth.ErrUnsupportedKeyAlgorithm
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
return nil, auth.ErrInvalidSymmetricKey
|
||||
}
|
||||
return &tokenizer{
|
||||
secret: secret,
|
||||
algorithm: alg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
tkn, err := smqjwt.BuildToken(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(signedBytes), nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(tokenString),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(km.algorithm, km.secret),
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Contains(err, errJWTExpiryKey) {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry)
|
||||
}
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
if tkn.Issuer() != smqjwt.IssuerName {
|
||||
return auth.Key{}, smqjwt.ErrInvalidIssuer
|
||||
}
|
||||
|
||||
return smqjwt.ToKey(tkn)
|
||||
}
|
||||
|
||||
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
return nil, auth.ErrPublicKeysNotSupported
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package symmetric_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/tokenizer/symmetric"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTokenizer(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
algorithm string
|
||||
secret []byte
|
||||
expectErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid HS256 algorithm",
|
||||
algorithm: "HS256",
|
||||
secret: []byte("my-secret-key-32-bytes-long!!"),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid HS384 algorithm",
|
||||
algorithm: "HS384",
|
||||
secret: []byte("my-secret-key-48-bytes-long-for-hs384-!!!!!!"),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid HS512 algorithm",
|
||||
algorithm: "HS512",
|
||||
secret: []byte("my-secret-key-64-bytes-long-for-hs512-algorithm-testing!!!!"),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid algorithm",
|
||||
algorithm: "INVALID_ALG",
|
||||
secret: []byte("my-secret-key"),
|
||||
expectErr: true,
|
||||
errContains: "unsupported key algorithm",
|
||||
},
|
||||
{
|
||||
name: "empty secret",
|
||||
algorithm: "HS256",
|
||||
secret: []byte{},
|
||||
expectErr: true,
|
||||
errContains: "invalid symmetric key",
|
||||
},
|
||||
{
|
||||
name: "nil secret",
|
||||
algorithm: "HS256",
|
||||
secret: nil,
|
||||
expectErr: true,
|
||||
errContains: "invalid symmetric key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
km, err := symmetric.NewTokenizer(tc.algorithm, tc.secret)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tc.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
}
|
||||
assert.Nil(t, km)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, km)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSign(t *testing.T) {
|
||||
secret := []byte("my-super-secret-key-for-testing")
|
||||
|
||||
km, err := symmetric.NewTokenizer("HS256", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
key auth.Key
|
||||
}{
|
||||
{
|
||||
name: "sign valid key with all fields",
|
||||
key: auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sign key without subject",
|
||||
key: auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.APIKey,
|
||||
Issuer: "supermq.auth",
|
||||
Role: auth.AdminRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour).UTC(),
|
||||
Verified: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sign key without ID",
|
||||
key: auth.Key{
|
||||
Type: auth.AccessKey,
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token, err := km.Issue(tc.key)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Verify the token is valid JWT format (3 parts separated by dots)
|
||||
parts := splitJWT(token)
|
||||
assert.Equal(t, 3, len(parts), "JWT should have 3 parts")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
secret := []byte("my-super-secret-key-for-testing")
|
||||
|
||||
km, err := symmetric.NewTokenizer("HS256", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
validKey := auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
validToken, err := km.Issue(validKey)
|
||||
require.NoError(t, err, "Signing valid token should succeed")
|
||||
|
||||
expiredKey := validKey
|
||||
expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC()
|
||||
expiredToken, err := km.Issue(expiredKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongIssuerKey := validKey
|
||||
wrongIssuerKey.Issuer = "wrong.issuer"
|
||||
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(wrongIssuerKey.Issuer).
|
||||
Subject(wrongIssuerKey.Subject).
|
||||
IssuedAt(wrongIssuerKey.IssuedAt).
|
||||
Expiration(wrongIssuerKey.ExpiresAt).
|
||||
JwtID(wrongIssuerKey.ID).
|
||||
Claim("type", wrongIssuerKey.Type).
|
||||
Claim("role", wrongIssuerKey.Role).
|
||||
Claim("verified", wrongIssuerKey.Verified)
|
||||
|
||||
wrongIssuerJWT, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongIssuerTokenBytes, err := jwt.Sign(wrongIssuerJWT, jwt.WithKey(jwa.HS256, secret))
|
||||
require.NoError(t, err)
|
||||
wrongIssuerToken := string(wrongIssuerTokenBytes)
|
||||
|
||||
wrongSecretKM, err := symmetric.NewTokenizer("HS256", []byte("different-secret-key-here"))
|
||||
require.NoError(t, err)
|
||||
wrongSecretToken, err := wrongSecretKM.Issue(validKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
expectErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "verify valid token",
|
||||
token: validToken,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verify expired token",
|
||||
token: expiredToken,
|
||||
expectErr: true,
|
||||
errContains: "exp",
|
||||
},
|
||||
{
|
||||
name: "verify token with wrong issuer",
|
||||
token: wrongIssuerToken,
|
||||
expectErr: true,
|
||||
errContains: "invalid token issuer",
|
||||
},
|
||||
{
|
||||
name: "verify token with wrong secret",
|
||||
token: wrongSecretToken,
|
||||
expectErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "verify malformed token",
|
||||
token: "not.a.valid.jwt",
|
||||
expectErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
key, err := km.Parse(context.Background(), tc.token)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tc.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, validKey.Subject, key.Subject)
|
||||
assert.Equal(t, validKey.Type, key.Type)
|
||||
assert.Equal(t, validKey.Role, key.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeys(t *testing.T) {
|
||||
secret := []byte("my-super-secret-key-for-testing")
|
||||
|
||||
km, err := symmetric.NewTokenizer("HS256", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := km.RetrieveJWKS()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, auth.ErrPublicKeysNotSupported, err)
|
||||
assert.Nil(t, keys)
|
||||
}
|
||||
|
||||
func TestSignAndVerifyRoundTrip(t *testing.T) {
|
||||
algorithms := []string{"HS256", "HS384", "HS512"}
|
||||
|
||||
for _, alg := range algorithms {
|
||||
t.Run(alg, func(t *testing.T) {
|
||||
secret := []byte("my-super-secret-key-for-testing-" + alg)
|
||||
|
||||
km, err := symmetric.NewTokenizer(alg, secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
originalKey := auth.Key{
|
||||
ID: "key-123",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-456",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC().Truncate(time.Second),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC().Truncate(time.Second),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
token, err := km.Issue(originalKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifiedKey, err := km.Parse(context.Background(), token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalKey.ID, verifiedKey.ID)
|
||||
assert.Equal(t, originalKey.Type, verifiedKey.Type)
|
||||
assert.Equal(t, originalKey.Subject, verifiedKey.Subject)
|
||||
assert.Equal(t, originalKey.Role, verifiedKey.Role)
|
||||
assert.Equal(t, originalKey.Verified, verifiedKey.Verified)
|
||||
assert.WithinDuration(t, originalKey.IssuedAt, verifiedKey.IssuedAt, time.Second)
|
||||
assert.WithinDuration(t, originalKey.ExpiresAt, verifiedKey.ExpiresAt, time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentAlgorithms(t *testing.T) {
|
||||
secret := []byte("my-super-secret-key-for-testing-algorithms")
|
||||
|
||||
key := auth.Key{
|
||||
ID: "key-id",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Subject: "user-id",
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour).UTC(),
|
||||
Verified: true,
|
||||
}
|
||||
|
||||
km256, err := symmetric.NewTokenizer("HS256", secret)
|
||||
require.NoError(t, err)
|
||||
token256, err := km256.Issue(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
km384, err := symmetric.NewTokenizer("HS384", secret)
|
||||
require.NoError(t, err)
|
||||
token384, err := km384.Issue(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
km512, err := symmetric.NewTokenizer("HS512", secret)
|
||||
require.NoError(t, err)
|
||||
token512, err := km512.Issue(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, token256, token384)
|
||||
assert.NotEqual(t, token256, token512)
|
||||
assert.NotEqual(t, token384, token512)
|
||||
|
||||
_, err = km256.Parse(context.Background(), token256)
|
||||
assert.NoError(t, err, "verification of km256 token should pass with km256 verifier")
|
||||
|
||||
_, err = km384.Parse(context.Background(), token384)
|
||||
assert.NoError(t, err, "verification of km384 token should pass with km384 verifier")
|
||||
|
||||
_, err = km512.Parse(context.Background(), token512)
|
||||
assert.NoError(t, err, "verification of km512 token should pass with km512 verifier")
|
||||
|
||||
_, err = km384.Parse(context.Background(), token256)
|
||||
assert.Error(t, err, "Cross verification should fail")
|
||||
|
||||
_, err = km512.Parse(context.Background(), token256)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func splitJWT(token string) []string {
|
||||
parts := []string{}
|
||||
start := 0
|
||||
for i := 0; i < len(token); i++ {
|
||||
if token[i] == '.' {
|
||||
parts = append(parts, token[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
parts = append(parts, token[start:])
|
||||
return parts
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidIssuer represents an invalid token issuer value.
|
||||
ErrInvalidIssuer = errors.New("invalid token issuer value")
|
||||
|
||||
// ErrJSONHandle indicates an error in handling JSON.
|
||||
ErrJSONHandle = errors.New("failed to perform operation JSON")
|
||||
|
||||
errInvalidType = errors.New("invalid token type")
|
||||
errInvalidRole = errors.New("invalid role")
|
||||
errInvalidVerified = errors.New("invalid verified")
|
||||
)
|
||||
|
||||
const (
|
||||
IssuerName = "supermq.auth"
|
||||
TokenType = "type"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
)
|
||||
|
||||
// ToKey converts a JWT token to an auth.Key by extracting claims.
|
||||
func ToKey(tkn jwt.Token) (auth.Key, error) {
|
||||
data, err := json.Marshal(tkn.PrivateClaims())
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
|
||||
}
|
||||
var key auth.Key
|
||||
if err := json.Unmarshal(data, &key); err != nil {
|
||||
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
|
||||
}
|
||||
|
||||
tType, ok := tkn.Get(TokenType)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
kType, ok := tType.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
kt := auth.KeyType(kType)
|
||||
if !kt.Validate() {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
|
||||
tRole, ok := tkn.Get(RoleField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
kRole, ok := tRole.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
tVerified, ok := tkn.Get(VerifiedField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
kVerified, ok := tVerified.(bool)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
|
||||
kr := auth.Role(kRole)
|
||||
if !kr.Validate() {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
key.ID = tkn.JwtID()
|
||||
key.Type = auth.KeyType(kType)
|
||||
key.Role = auth.Role(kRole)
|
||||
key.Issuer = tkn.Issuer()
|
||||
key.Subject = tkn.Subject()
|
||||
key.IssuedAt = tkn.IssuedAt()
|
||||
key.ExpiresAt = tkn.Expiration()
|
||||
key.Verified = kVerified
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func BuildToken(key auth.Key) (jwt.Token, error) {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.
|
||||
Issuer(IssuerName).
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(TokenType, key.Type).
|
||||
Expiration(key.ExpiresAt).
|
||||
Claim(RoleField, key.Role).
|
||||
Claim(VerifiedField, key.Verified)
|
||||
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
if key.ID != "" {
|
||||
builder.JwtID(key.ID)
|
||||
}
|
||||
|
||||
return builder.Build()
|
||||
}
|
||||
@@ -136,8 +136,7 @@ type Service interface {
|
||||
// ChannelRepository specifies a channel persistence API.
|
||||
type Repository interface {
|
||||
// Save persists multiple channels. Channels are saved using a transaction. If one channel
|
||||
// fails then none will be saved. Successful operation is indicated by non-nil
|
||||
// error response.
|
||||
// fails then none will be saved. Successful operation is indicated by non-nil error response.
|
||||
Save(ctx context.Context, chs ...Channel) ([]Channel, error)
|
||||
|
||||
// Update performs an update to the existing channel.
|
||||
|
||||
+42
-11
@@ -22,11 +22,10 @@ import (
|
||||
httpapi "github.com/absmach/supermq/auth/api/http"
|
||||
"github.com/absmach/supermq/auth/cache"
|
||||
"github.com/absmach/supermq/auth/hasher"
|
||||
"github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/auth/keymanager/asymmetric"
|
||||
"github.com/absmach/supermq/auth/keymanager/symmetric"
|
||||
"github.com/absmach/supermq/auth/middleware"
|
||||
apostgres "github.com/absmach/supermq/auth/postgres"
|
||||
"github.com/absmach/supermq/auth/tokenizer/asymmetric"
|
||||
"github.com/absmach/supermq/auth/tokenizer/symmetric"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/pkg/jaeger"
|
||||
@@ -69,7 +68,8 @@ type config struct {
|
||||
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
|
||||
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
|
||||
KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"`
|
||||
PrivateKeyPath string `env:"SMQ_AUTH_KEYS_PRIVATE_KEY_PATH" envDefault:"./ssl/keys/private.pem"`
|
||||
ActiveKeyPath string `env:"SMQ_AUTH_KEYS_ACTIVE_KEY_PATH" envDefault:"./keys/active.key"`
|
||||
RetiringKeyPath string `env:"SMQ_AUTH_KEYS_RETIRING_KEY_PATH" envDefault:""`
|
||||
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
@@ -159,17 +159,23 @@ func main() {
|
||||
|
||||
idProvider := uuid.New()
|
||||
|
||||
var keyManager auth.KeyManager
|
||||
if err := validateKeyConfig(isSymmetric, cfg, logger); err != nil {
|
||||
logger.Error(fmt.Sprintf("invalid key configuration: %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
var tokenizer auth.Tokenizer
|
||||
switch {
|
||||
case isSymmetric:
|
||||
keyManager, err = symmetric.NewKeyManager(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
|
||||
tokenizer, err = symmetric.NewTokenizer(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
default:
|
||||
keyManager, err = asymmetric.NewKeyManager(cfg.PrivateKeyPath, idProvider)
|
||||
tokenizer, err = asymmetric.NewTokenizer(cfg.ActiveKeyPath, cfg.RetiringKeyPath, idProvider, logger)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error()))
|
||||
exitCode = 1
|
||||
@@ -177,7 +183,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, keyManager, idProvider)
|
||||
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, tokenizer, idProvider)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create service : %s\n", err.Error()))
|
||||
exitCode = 1
|
||||
@@ -258,7 +264,34 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
|
||||
return nil
|
||||
}
|
||||
|
||||
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, keyManager auth.KeyManager, idProvider supermq.IDProvider) (auth.Service, error) {
|
||||
func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error {
|
||||
if isSymmetric {
|
||||
if cfg.SecretKey == "secret" {
|
||||
return fmt.Errorf("default secret key is insecure - please set SMQ_AUTH_SECRET_KEY environment variable")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate active key path
|
||||
_, err := os.Stat(cfg.ActiveKeyPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("active key file not found: %s - please set SMQ_AUTH_KEYS_ACTIVE_KEY_PATH", cfg.ActiveKeyPath)
|
||||
}
|
||||
return fmt.Errorf("failed to access active key file: %w", err)
|
||||
}
|
||||
|
||||
// Retiring key is optional - only validate if path is provided
|
||||
if cfg.RetiringKeyPath != "" {
|
||||
if _, err := os.Stat(cfg.RetiringKeyPath); err != nil {
|
||||
l.Warn("retiring key path provided but file not accessible", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) {
|
||||
cache := cache.NewPatsCache(cacheClient, keyDuration)
|
||||
|
||||
database := pgclient.NewDatabase(db, dbConfig, tracer)
|
||||
@@ -269,8 +302,6 @@ func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.
|
||||
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
|
||||
pService := spicedb.NewPolicyService(spicedbClient, logger)
|
||||
|
||||
tokenizer := jwt.New(keyManager)
|
||||
|
||||
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
|
||||
svc = middleware.NewLogging(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics("auth", "api")
|
||||
|
||||
+2
-1
@@ -101,7 +101,8 @@ SMQ_AUTH_DB_SSL_ROOT_CERT=
|
||||
SMQ_AUTH_ACCESS_TOKEN_DURATION="1h"
|
||||
SMQ_AUTH_REFRESH_TOKEN_DURATION="24h"
|
||||
SMQ_AUTH_KEYS_ALGORITHM="EdDSA"
|
||||
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH="./ssl/keys/private.key"
|
||||
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key"
|
||||
SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
|
||||
SMQ_AUTH_INVITATION_DURATION="168h"
|
||||
SMQ_AUTH_ADAPTER_INSTANCE_ID=
|
||||
SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0
|
||||
|
||||
@@ -121,7 +121,8 @@ services:
|
||||
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
|
||||
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:+/keys/private.key}
|
||||
SMQ_AUTH_KEYS_ACTIVE_KEY_PATH: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH:+/keys/active.key}
|
||||
SMQ_AUTH_KEYS_RETIRING_KEY_PATH: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+/keys/retiring.key}
|
||||
## Compose supports parameter expansion in environment,
|
||||
## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty
|
||||
## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default
|
||||
@@ -152,11 +153,16 @@ services:
|
||||
volumes:
|
||||
- ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE}
|
||||
- supermq-pat-db-volume:/supermq-data
|
||||
- supermq-auth-keys-volume:/keys
|
||||
# Auth private key file
|
||||
# Auth active private key file
|
||||
- type: bind
|
||||
source: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:-ssl/certs/dummy/private_key}
|
||||
target: /keys/private.key
|
||||
source: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH}
|
||||
target: /keys/active.key
|
||||
read_only: true
|
||||
# Auth retiring private key file (optional, for key rotation)
|
||||
- type: bind
|
||||
source: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:-ssl/certs/dummy/retiring_key}
|
||||
target: /keys/retiring${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+.key}
|
||||
read_only: true
|
||||
bind:
|
||||
create_host_path: true
|
||||
# Auth gRPC mTLS server certificates
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIE9Qu5lN6KOfdO14XJUClM1UPrqT55BczLMcRuSG7Ziy
|
||||
-----END PRIVATE KEY-----
|
||||
+78
-34
@@ -5,6 +5,7 @@ package jwks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
|
||||
smqauth "github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/api/grpc/auth"
|
||||
smqjwt "github.com/absmach/supermq/auth/jwt"
|
||||
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -26,8 +27,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
issuerName = "supermq.auth"
|
||||
cacheDuration = 5 * time.Minute
|
||||
issuerName = "supermq.auth"
|
||||
acceptHeader = "application/json"
|
||||
|
||||
fetchKeyDeadline = 10 * time.Second
|
||||
cacheDuration = 5 * time.Minute
|
||||
|
||||
errorBodyBytes = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,12 +45,6 @@ var (
|
||||
errInvalidIssuer = errors.New("invalid token issuer value")
|
||||
// ErrValidateJWTToken indicates a failure to validate JWT token.
|
||||
errValidateJWTToken = errors.New("failed to validate jwt token")
|
||||
|
||||
jwksCache = struct {
|
||||
sync.RWMutex
|
||||
jwks jwk.Set
|
||||
cachedAt time.Time
|
||||
}{}
|
||||
)
|
||||
|
||||
var _ authn.Authentication = (*authentication)(nil)
|
||||
@@ -52,6 +52,14 @@ var _ authn.Authentication = (*authentication)(nil)
|
||||
type authentication struct {
|
||||
jwksURL string
|
||||
authSvcClient grpcAuthV1.AuthServiceClient
|
||||
httpClient *http.Client
|
||||
cache *jwksCache
|
||||
}
|
||||
|
||||
type jwksCache struct {
|
||||
mu sync.RWMutex
|
||||
jwks jwk.Set
|
||||
cachedAt time.Time
|
||||
}
|
||||
|
||||
func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
|
||||
@@ -69,9 +77,13 @@ func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Confi
|
||||
}
|
||||
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
|
||||
|
||||
httpClient := &http.Client{}
|
||||
|
||||
return authentication{
|
||||
jwksURL: jwksURL,
|
||||
authSvcClient: authSvcClient,
|
||||
httpClient: httpClient,
|
||||
cache: &jwksCache{},
|
||||
}, client, nil
|
||||
}
|
||||
|
||||
@@ -84,14 +96,26 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
|
||||
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
|
||||
}
|
||||
|
||||
jwks, err := a.fetchJWKS(ctx)
|
||||
jwks, err := a.fetchJWKS(ctx, false)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
tkn, err := validateToken(token, jwks)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
// If signature verification failed, try force with refresh JWKS (key rotation scenario)
|
||||
if isSignatureError(err) {
|
||||
jwks, fetchErr := a.fetchJWKS(ctx, true)
|
||||
if fetchErr == nil {
|
||||
tkn, err = validateToken(token, jwks)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
}
|
||||
|
||||
key, err := smqjwt.ToKey(tkn)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
@@ -105,45 +129,63 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a authentication) fetchJWKS(ctx context.Context) (jwk.Set, error) {
|
||||
jwksCache.RLock()
|
||||
if time.Since(jwksCache.cachedAt) < cacheDuration && jwksCache.jwks.Len() > 0 {
|
||||
cached := jwksCache.jwks
|
||||
jwksCache.RUnlock()
|
||||
return cached, nil
|
||||
}
|
||||
jwksCache.RUnlock()
|
||||
func isSignatureError(err error) bool {
|
||||
return !errors.Contains(err, errJWTExpiryKey) &&
|
||||
!errors.Contains(err, errInvalidIssuer) &&
|
||||
!errors.Contains(err, smqauth.ErrExpiry)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", a.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (a authentication) fetchJWKS(ctx context.Context, forceRefresh bool) (jwk.Set, error) {
|
||||
if !forceRefresh {
|
||||
a.cache.mu.RLock()
|
||||
if time.Since(a.cache.cachedAt) < cacheDuration && a.cache.jwks.Len() > 0 {
|
||||
cached := a.cache.jwks
|
||||
a.cache.mu.RUnlock()
|
||||
return cached, nil
|
||||
}
|
||||
a.cache.mu.RUnlock()
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
fetchCtx := ctx
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
fetchCtx, cancel = context.WithTimeout(ctx, fetchKeyDeadline)
|
||||
defer cancel()
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
|
||||
// Fetch fresh JWKS from auth service
|
||||
req, err := http.NewRequestWithContext(fetchCtx, http.MethodGet, a.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(errFetchJWKS, err)
|
||||
}
|
||||
req.Header.Set("Accept", acceptHeader)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errFetchJWKS, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errFetchJWKS
|
||||
// Read error body for better diagnostics
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, errorBodyBytes))
|
||||
return nil, errors.Wrap(errFetchJWKS, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(errFetchJWKS, err)
|
||||
}
|
||||
|
||||
set, err := jwk.Parse(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(errFetchJWKS, err)
|
||||
}
|
||||
jwksCache.Lock()
|
||||
jwksCache.jwks = set
|
||||
jwksCache.cachedAt = time.Now()
|
||||
jwksCache.Unlock()
|
||||
|
||||
a.cache.mu.Lock()
|
||||
a.cache.jwks = set
|
||||
a.cache.cachedAt = time.Now()
|
||||
a.cache.mu.Unlock()
|
||||
|
||||
return set, nil
|
||||
}
|
||||
@@ -160,6 +202,8 @@ func validateToken(token string, jwks jwk.Set) (jwt.Token, error) {
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate issuer
|
||||
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
|
||||
if t.Issuer() != issuerName {
|
||||
return jwt.NewValidationError(errInvalidIssuer)
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jwks_test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqjwt "github.com/absmach/supermq/auth/tokenizer/util"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
validIssuer = "supermq.auth"
|
||||
invalidIssuer = "invalid.issuer"
|
||||
userID = "user123"
|
||||
)
|
||||
|
||||
func TestValidateToken(t *testing.T) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
kid := "test-key-id"
|
||||
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
require.NoError(t, err, "Parsing JWKS private key should work")
|
||||
require.NoError(t, privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk.Set(jwk.KeyIDKey, kid))
|
||||
|
||||
publicJwk, err := jwk.FromRaw(publicKey)
|
||||
require.NoError(t, err, "Parsing JWKS public key should work")
|
||||
require.NoError(t, publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk.Set(jwk.KeyIDKey, kid))
|
||||
|
||||
jwksSet := jwk.NewSet()
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk), "Creation of JWKS should succeed")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
issuer string
|
||||
expiry time.Time
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(1 * time.Hour),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(-1 * time.Hour),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer",
|
||||
issuer: invalidIssuer,
|
||||
expiry: time.Now().Add(1 * time.Hour),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(tc.issuer)
|
||||
builder.IssuedAt(time.Now())
|
||||
builder.Expiration(tc.expiry)
|
||||
builder.Subject(userID)
|
||||
builder.Claim(smqjwt.TokenType, auth.AccessKey)
|
||||
builder.Claim(smqjwt.RoleField, auth.UserRole)
|
||||
builder.Claim(smqjwt.VerifiedField, true)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk))
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedToken, err := jwt.Parse(
|
||||
signedToken,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(jwksSet),
|
||||
)
|
||||
|
||||
if tc.expectErr {
|
||||
// For expired token, parsing will fail
|
||||
// For invalid issuer, parsing succeeds but we need to check issuer
|
||||
if tc.issuer != validIssuer && err == nil {
|
||||
if parsedToken.Issuer() != validIssuer {
|
||||
assert.NotEqual(t, validIssuer, parsedToken.Issuer())
|
||||
} else {
|
||||
assert.Fail(t, "Expected invalid issuer error")
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, parsedToken)
|
||||
assert.Equal(t, tc.issuer, parsedToken.Issuer())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiKeyJWKS(t *testing.T) {
|
||||
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub2, priv2, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeKID := "key-active"
|
||||
retiringKID := "key-retiring"
|
||||
|
||||
privateJwk1, err := jwk.FromRaw(priv1)
|
||||
require.NoError(t, err, "Parsing JWKS private key 1 should work")
|
||||
require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, activeKID))
|
||||
|
||||
publicJwk1, err := jwk.FromRaw(pub1)
|
||||
require.NoError(t, err, "Parsing JWKS public key 1 should work")
|
||||
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, activeKID))
|
||||
|
||||
privateJwk2, err := jwk.FromRaw(priv2)
|
||||
require.NoError(t, err, "Parsing JWKS private key 2 should work")
|
||||
require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, retiringKID))
|
||||
|
||||
publicJwk2, err := jwk.FromRaw(pub2)
|
||||
require.NoError(t, err, "Parsing JWKS public key 2 should work")
|
||||
require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, retiringKID))
|
||||
|
||||
// Create JWKS set with both keys (simulating rotation period)
|
||||
jwksSet := jwk.NewSet()
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk1))
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk2))
|
||||
|
||||
assert.Equal(t, 2, jwksSet.Len(), "JWKS should contain both keys")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
privateKey jwk.Key
|
||||
kid string
|
||||
issuer string
|
||||
expiry time.Time
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "token signed with active key",
|
||||
privateKey: privateJwk1,
|
||||
kid: activeKID,
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(1 * time.Hour),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "token signed with retiring key",
|
||||
privateKey: privateJwk2,
|
||||
kid: retiringKID,
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(1 * time.Hour),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired token with active key",
|
||||
privateKey: privateJwk1,
|
||||
kid: activeKID,
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(-1 * time.Hour),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired token with retiring key",
|
||||
privateKey: privateJwk2,
|
||||
kid: retiringKID,
|
||||
issuer: validIssuer,
|
||||
expiry: time.Now().Add(-1 * time.Hour),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(tc.issuer)
|
||||
builder.IssuedAt(time.Now())
|
||||
builder.Expiration(tc.expiry)
|
||||
builder.Subject(userID)
|
||||
builder.Claim(smqjwt.TokenType, auth.AccessKey)
|
||||
builder.Claim(smqjwt.RoleField, auth.UserRole)
|
||||
builder.Claim(smqjwt.VerifiedField, true)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, tc.privateKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the token using the multi-key JWKS set
|
||||
parsedToken, err := jwt.Parse(
|
||||
signedToken,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, parsedToken)
|
||||
assert.Equal(t, tc.issuer, parsedToken.Issuer())
|
||||
assert.Equal(t, userID, parsedToken.Subject())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyIDMatching(t *testing.T) {
|
||||
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err, "Generating key par 1 should succeed")
|
||||
|
||||
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err, "Generating key par 2 should succeed")
|
||||
|
||||
kid1 := "key-1"
|
||||
kid2 := "key-2"
|
||||
|
||||
privateJwk1, err := jwk.FromRaw(priv1)
|
||||
require.NoError(t, err, "Parsing JWKS private key 1 should work")
|
||||
require.NoError(t, privateJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk1.Set(jwk.KeyIDKey, kid1))
|
||||
|
||||
publicJwk1, err := jwk.FromRaw(pub1)
|
||||
require.NoError(t, err, "Parsing JWKS public key 1 should work")
|
||||
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1))
|
||||
|
||||
publicJwk2, err := jwk.FromRaw(pub2)
|
||||
require.NoError(t, err, "Parsing JWKS public key 2 should work")
|
||||
require.NoError(t, publicJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk2.Set(jwk.KeyIDKey, kid2))
|
||||
|
||||
jwksSet := jwk.NewSet()
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to JWKS set should succeed")
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk2), "Adding public key 2 to JWKS set should succeed")
|
||||
|
||||
// Create token signed with key-1
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(validIssuer)
|
||||
builder.IssuedAt(time.Now())
|
||||
builder.Expiration(time.Now().Add(1 * time.Hour))
|
||||
builder.Subject(userID)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk1))
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedToken, err := jwt.Parse(
|
||||
signedToken,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, parsedToken)
|
||||
assert.Equal(t, validIssuer, parsedToken.Issuer())
|
||||
|
||||
headers, err := jws.Parse(signedToken)
|
||||
require.NoError(t, err)
|
||||
sigs := headers.Signatures()
|
||||
require.Len(t, sigs, 1, "JWT should have exactly one signature")
|
||||
|
||||
kidValue := sigs[0].ProtectedHeaders().KeyID()
|
||||
assert.Equal(t, kid1, kidValue, "JWT kid header should match signing key")
|
||||
}
|
||||
|
||||
func TestUnknownKeyID(t *testing.T) {
|
||||
pub1, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err, "Generating public key 1 should succeed")
|
||||
|
||||
_, priv2, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err, "Generating private key 2 should succeed")
|
||||
|
||||
kid1 := "known-key"
|
||||
kid2 := "unknown-key"
|
||||
|
||||
publicJwk1, err := jwk.FromRaw(pub1)
|
||||
require.NoError(t, err, "Parsing JWKS public key 1 should work")
|
||||
require.NoError(t, publicJwk1.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, publicJwk1.Set(jwk.KeyIDKey, kid1))
|
||||
|
||||
privateJwk2, err := jwk.FromRaw(priv2)
|
||||
require.NoError(t, err, "Parsing JWKS private key 2 should work")
|
||||
require.NoError(t, privateJwk2.Set(jwk.AlgorithmKey, jwa.EdDSA))
|
||||
require.NoError(t, privateJwk2.Set(jwk.KeyIDKey, kid2))
|
||||
|
||||
jwksSet := jwk.NewSet()
|
||||
require.NoError(t, jwksSet.AddKey(publicJwk1), "Adding public key 1 to set should succeed")
|
||||
|
||||
// Create token signed with unknown key-2
|
||||
builder := jwt.NewBuilder()
|
||||
builder.Issuer(validIssuer)
|
||||
builder.IssuedAt(time.Now())
|
||||
builder.Expiration(time.Now().Add(1 * time.Hour))
|
||||
builder.Subject(userID)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.EdDSA, privateJwk2))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = jwt.Parse(
|
||||
signedToken,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(jwksSet, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
|
||||
assert.Error(t, err, "Should fail when token's kid is not in JWKS")
|
||||
}
|
||||
@@ -58,7 +58,6 @@ packages:
|
||||
Cache:
|
||||
Hasher:
|
||||
KeyRepository:
|
||||
KeyManager:
|
||||
Tokenizer:
|
||||
PATS:
|
||||
PATSRepository:
|
||||
|
||||
Reference in New Issue
Block a user