SMQ-1672 - Add asymmetric key authentication (#3228)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-12-23 23:16:06 +03:00
committed by GitHub
parent 59f8d4e4d7
commit 6a5d28c65a
36 changed files with 2060 additions and 560 deletions
+12
View File
@@ -85,3 +85,15 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint {
return revokeKeyRes{}, nil
}
}
func retrieveJWKSEndpoint(svc auth.Service, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
jwks := svc.RetrieveJWKS()
return retrieveJWKSRes{
Keys: jwks,
CacheMaxAge: jwksCacheMaxAge,
CacheStaleWhileRevalidate: jwksCacheStaleWhileRevalidate,
}, nil
}
}
+104 -93
View File
@@ -4,7 +4,8 @@
package keys_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
@@ -17,30 +18,26 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
httpapi "github.com/absmach/supermq/auth/api/http"
"github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/auth/mocks"
smqlog "github.com/absmach/supermq/logger"
svcerr "github.com/absmach/supermq/pkg/errors/service"
policymocks "github.com/absmach/supermq/pkg/policies/mocks"
"github.com/absmach/supermq/pkg/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
secret = "secret"
contentType = "application/json"
id = "123e4567-e89b-12d3-a456-000000000001"
email = "user@example.com"
loginDuration = 30 * time.Minute
refreshDuration = 24 * time.Hour
invalidDuration = 7 * 24 * time.Hour
accessToken = "valid token"
)
var (
krepo *mocks.KeyRepository
pEvaluator *policymocks.Evaluator
)
var Token = auth.Token{
AccessToken: accessToken,
}
type issueRequest struct {
Duration time.Duration `json:"duration,omitempty"`
@@ -72,22 +69,11 @@ func (tr testRequest) make() (*http.Response, error) {
return tr.client.Do(req)
}
func newService() auth.Service {
krepo = new(mocks.KeyRepository)
pRepo := new(mocks.PATSRepository)
cache := new(mocks.Cache)
hash := new(mocks.Hasher)
idProvider := uuid.NewMock()
pService := new(policymocks.Service)
pEvaluator = new(policymocks.Evaluator)
t := jwt.New([]byte(secret))
func newServer() (*httptest.Server, *mocks.Service) {
svc := new(mocks.Service)
mux := httpapi.MakeHandler(svc, smqlog.NewMock(), "", 900, 60)
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration)
}
func newServer(svc auth.Service) *httptest.Server {
mux := httpapi.MakeHandler(svc, smqlog.NewMock(), "")
return httptest.NewServer(mux)
return httptest.NewServer(mux), svc
}
func toJSON(data any) string {
@@ -99,13 +85,7 @@ func toJSON(data any) string {
}
func TestIssue(t *testing.T) {
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
policyCall.Unset()
ts := newServer(svc)
ts, svc := newServer()
defer ts.Close()
client := ts.Client()
@@ -119,6 +99,8 @@ func TestIssue(t *testing.T) {
ct string
token string
status int
svcRes auth.Token
svcErr error
}{
{
desc: "issue login key with empty token",
@@ -131,28 +113,30 @@ func TestIssue(t *testing.T) {
desc: "issue API key",
req: toJSON(ak),
ct: contentType,
token: token.AccessToken,
token: accessToken,
status: http.StatusCreated,
svcRes: Token,
},
{
desc: "issue recovery key",
req: toJSON(rk),
ct: contentType,
token: token.AccessToken,
token: accessToken,
status: http.StatusCreated,
svcRes: Token,
},
{
desc: "issue login key wrong content type",
req: toJSON(lk),
ct: "",
token: token.AccessToken,
token: accessToken,
status: http.StatusUnsupportedMediaType,
},
{
desc: "issue recovery key wrong content type",
req: toJSON(rk),
ct: "",
token: token.AccessToken,
token: accessToken,
status: http.StatusUnsupportedMediaType,
},
{
@@ -160,6 +144,7 @@ func TestIssue(t *testing.T) {
req: toJSON(ak),
ct: contentType,
token: "wrong",
svcErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
},
{
@@ -167,27 +152,28 @@ func TestIssue(t *testing.T) {
req: toJSON(rk),
ct: contentType,
token: "",
svcErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
},
{
desc: "issue key with invalid request",
req: "{",
ct: contentType,
token: token.AccessToken,
token: accessToken,
status: http.StatusBadRequest,
},
{
desc: "issue key with invalid JSON",
req: "{invalid}",
ct: contentType,
token: token.AccessToken,
token: accessToken,
status: http.StatusBadRequest,
},
{
desc: "issue key with invalid JSON content",
req: `{"Type":{"key":"AccessToken"}}`,
ct: contentType,
token: token.AccessToken,
token: accessToken,
status: http.StatusBadRequest,
},
}
@@ -201,30 +187,16 @@ func TestIssue(t *testing.T) {
token: tc.token,
body: strings.NewReader(tc.req),
}
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
svcCall := svc.On("Issue", mock.Anything, tc.token, mock.Anything).Return(tc.svcRes, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
repoCall.Unset()
policyCall.Unset()
svcCall.Unset()
}
}
func TestRetrieve(t *testing.T) {
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil)
k, err := svc.Issue(context.Background(), token.AccessToken, key)
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
repoCall.Unset()
policyCall.Unset()
ts := newServer(svc)
ts, svc := newServer()
defer ts.Close()
client := ts.Client()
@@ -234,12 +206,13 @@ func TestRetrieve(t *testing.T) {
token string
key auth.Key
status int
err error
svcRes auth.Key
svcErr error
}{
{
desc: "retrieve an existing key",
id: k.AccessToken,
token: token.AccessToken,
id: id,
token: accessToken,
key: auth.Key{
Subject: id,
Type: auth.AccessKey,
@@ -247,28 +220,28 @@ func TestRetrieve(t *testing.T) {
ExpiresAt: time.Now().Add(refreshDuration),
},
status: http.StatusOK,
err: nil,
svcErr: nil,
},
{
desc: "retrieve a non-existing key",
id: "non-existing",
token: token.AccessToken,
token: accessToken,
status: http.StatusNotFound,
err: svcerr.ErrNotFound,
svcErr: svcerr.ErrNotFound,
},
{
desc: "retrieve a key with an invalid token",
id: k.AccessToken,
id: accessToken,
token: "wrong",
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
},
{
desc: "retrieve a key with an empty token",
token: "",
id: k.AccessToken,
id: accessToken,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
},
}
@@ -279,30 +252,16 @@ func TestRetrieve(t *testing.T) {
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
token: tc.token,
}
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
repoCall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err)
svcCall := svc.On("RetrieveKey", mock.Anything, tc.token, tc.id).Return(tc.svcRes, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
repoCall.Unset()
policyCall.Unset()
svcCall.Unset()
}
}
func TestRevoke(t *testing.T) {
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil)
k, err := svc.Issue(context.Background(), token.AccessToken, key)
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
repoCall.Unset()
policyCall.Unset()
ts := newServer(svc)
ts, svc := newServer()
defer ts.Close()
client := ts.Client()
@@ -311,29 +270,33 @@ func TestRevoke(t *testing.T) {
id string
token string
status int
svcErr error
}{
{
desc: "revoke an existing key",
id: k.AccessToken,
token: token.AccessToken,
id: id,
token: accessToken,
status: http.StatusNoContent,
},
{
desc: "revoke a non-existing key",
id: "non-existing",
token: token.AccessToken,
status: http.StatusNoContent,
token: accessToken,
svcErr: svcerr.ErrNotFound,
status: http.StatusNotFound,
},
{
desc: "revoke key with invalid token",
id: k.AccessToken,
id: id,
token: "wrong",
svcErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
},
{
desc: "revoke key with empty token",
id: k.AccessToken,
id: id,
token: "",
svcErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
},
}
@@ -345,10 +308,58 @@ func TestRevoke(t *testing.T) {
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
token: tc.token,
}
repoCall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil)
svcCall := svc.On("Revoke", mock.Anything, tc.token, tc.id).Return(tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
repoCall.Unset()
svcCall.Unset()
}
}
func TestRetrieveJWKS(t *testing.T) {
ts, svc := newServer()
defer ts.Close()
client := ts.Client()
cases := []struct {
desc string
svcRes []auth.JWK
status int
}{
{
desc: "retrieve JWKS with keys",
svcRes: []auth.JWK{newJWK(t), newJWK(t)},
status: http.StatusOK,
},
{
desc: "retrieve empty JWKS",
svcRes: []auth.JWK{},
status: http.StatusOK,
},
}
for _, tc := range cases {
req := testRequest{
client: client,
method: http.MethodGet,
url: fmt.Sprintf("%s/keys/.well-known/jwks.json", ts.URL),
}
svcCall := svc.On("RetrieveJWKS").Return(tc.svcRes)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
}
}
func newJWK(t *testing.T) auth.JWK {
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
return auth.NewJWK(jwkKey)
}
+2
View File
@@ -46,3 +46,5 @@ func (req keyReq) validate() error {
}
return nil
}
type jwksReq struct{}
+39
View File
@@ -4,16 +4,21 @@
package keys
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwk"
)
var (
_ supermq.Response = (*issueKeyRes)(nil)
_ supermq.Response = (*revokeKeyRes)(nil)
_ supermq.Response = (*retrieveKeyRes)(nil)
_ supermq.Response = (*retrieveJWKSRes)(nil)
)
type issueKeyRes struct {
@@ -69,3 +74,37 @@ func (res revokeKeyRes) Headers() map[string]string {
func (res revokeKeyRes) Empty() bool {
return true
}
type retrieveJWKSRes struct {
Keys []auth.JWK `json:"-"`
CacheMaxAge int `json:"-"`
CacheStaleWhileRevalidate int `json:"-"`
}
func (res retrieveJWKSRes) MarshalJSON() ([]byte, error) {
set := jwk.NewSet()
for _, k := range res.Keys {
if err := set.AddKey(k.Key()); err != nil {
return nil, err
}
}
return json.Marshal(set)
}
func (res retrieveJWKSRes) Code() int {
return http.StatusOK
}
func (res retrieveJWKSRes) Headers() map[string]string {
cacheControl := fmt.Sprintf("public, max-age=%d, stale-while-revalidate=%d", res.CacheMaxAge, res.CacheStaleWhileRevalidate)
headers := map[string]string{
"Cache-Control": cacheControl,
}
return headers
}
func (res retrieveJWKSRes) Empty() bool {
return false
}
+13 -1
View File
@@ -21,7 +21,7 @@ import (
const contentType = "application/json"
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -46,6 +46,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Get("/.well-known/jwks.json", kithttp.NewServer(
retrieveJWKSEndpoint(svc, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate),
decodeJWKSReq,
api.EncodeResponse,
opts...,
).ServeHTTP)
})
return mux
}
@@ -70,3 +77,8 @@ func decodeKeyReq(_ context.Context, r *http.Request) (any, error) {
}
return req, nil
}
func decodeJWKSReq(_ context.Context, _ *http.Request) (any, error) {
req := jwksReq{}
return req, nil
}
+2 -2
View File
@@ -15,10 +15,10 @@ import (
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string) http.Handler {
func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) http.Handler {
mux := chi.NewRouter()
mux = keys.MakeHandler(svc, mux, logger)
mux = keys.MakeHandler(svc, mux, logger, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate)
mux = pats.MakeHandler(svc, mux, logger)
mux.Get("/health", supermq.Health("auth", instanceID))
+188 -85
View File
@@ -4,61 +4,59 @@
package jwt_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"testing"
"time"
"github.com/absmach/supermq/auth"
authjwt "github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/auth/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
tokenType = "type"
roleField = "role"
issuerName = "supermq.auth"
secret = "test"
tokenType = "type"
roleField = "role"
VerifiedField = "verified"
issuerName = "supermq.auth"
)
var reposecret = []byte("test")
func newToken(issuerName string, key auth.Key) string {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
IssuedAt(key.IssuedAt).
Claim(tokenType, "r").
Expiration(key.ExpiresAt)
builder.Claim(roleField, key.Role)
if key.Subject != "" {
builder.Subject(key.Subject)
}
if key.ID != "" {
builder.JwtID(key.ID)
}
tkn, _ := builder.Build()
tokn, _ := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, reposecret))
return string(tokn)
}
var (
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
keyManager = new(mocks.KeyManager)
)
func TestIssue(t *testing.T) {
tokenizer := authjwt.New([]byte(secret))
tokenizer := authjwt.New(keyManager)
validKey := key()
signedToken, _, err := signToken(issuerName, validKey, false)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
cases := []struct {
desc string
key auth.Key
err error
desc string
key auth.Key
managerReq jwt.Token
managerResp []byte
managerErr error
err error
}{
{
desc: "issue new token",
key: key(),
err: nil,
desc: "issue new token",
key: validKey,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token with OAuth token",
@@ -69,7 +67,8 @@ func TestIssue(t *testing.T) {
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
err: nil,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a domain",
@@ -79,7 +78,8 @@ func TestIssue(t *testing.T) {
Subject: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a subject",
@@ -89,7 +89,8 @@ func TestIssue(t *testing.T) {
Subject: "",
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without type",
@@ -99,7 +100,8 @@ func TestIssue(t *testing.T) {
Subject: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token without a domain and subject",
@@ -110,116 +112,161 @@ func TestIssue(t *testing.T) {
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
err: nil,
managerResp: []byte(signedToken),
err: nil,
},
{
desc: "issue token with failed to sign jwt",
key: validKey,
managerErr: svcerr.ErrAuthentication,
err: authjwt.ErrSignJWT,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
tc.managerReq = newToken(issuerName, tc.key)
kmCall := keyManager.On("SignJWT", tc.managerReq).Return(tc.managerResp, tc.managerErr)
tkn, err := tokenizer.Issue(tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err != nil {
if err == nil {
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
}
kmCall.Unset()
})
}
}
func TestParse(t *testing.T) {
tokenizer := authjwt.New([]byte(secret))
tokenizer := authjwt.New(keyManager)
token, err := tokenizer.Issue(key())
validKey := key()
signedTkn, parsedTkn, err := signToken(issuerName, validKey, true)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
apiKey := key()
apiKey.Type = auth.APIKey
apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
apiToken, err := tokenizer.Issue(apiKey)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
apiToken, _, err := signToken(issuerName, apiKey, false)
require.Nil(t, err, fmt.Sprintf("issuing api key expected to succeed: %s", err))
expKey := key()
expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
expToken, err := tokenizer.Issue(expKey)
expToken, _, err := signToken(issuerName, expKey, false)
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
emptySubjectKey := key()
emptySubjectKey.Subject = ""
emptySubjectToken, err := tokenizer.Issue(emptySubjectKey)
signedEmptySubjectTkn, parsedEmptySubjectTkn, err := signToken(issuerName, emptySubjectKey, true)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptyTypeKey := key()
emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1)
emptyTypeToken, err := tokenizer.Issue(emptyTypeKey)
emptyTypeToken, _, err := signToken(issuerName, emptyTypeKey, false)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptyKey := key()
emptyKey.Subject = ""
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
inValidToken := newToken("invalid", key())
signedInValidTkn, parsedInvalidTkn, err := signToken("invalid.issuer", key(), true)
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
cases := []struct {
desc string
key auth.Key
token string
err error
desc string
key auth.Key
token string
managerRes jwt.Token
managerErr error
err error
}{
{
desc: "parse valid key",
key: key(),
token: token,
err: nil,
desc: "parse valid key",
key: validKey,
token: signedTkn,
managerRes: parsedTkn,
err: nil,
},
{
desc: "parse invalid key",
key: auth.Key{},
token: "invalid",
err: svcerr.ErrAuthentication,
desc: "parse invalid key",
key: auth.Key{},
token: "invalid",
managerErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "parse expired key",
key: auth.Key{},
token: expToken,
err: auth.ErrExpiry,
desc: "parse expired key",
key: auth.Key{},
token: expToken,
managerErr: errJWTExpiryKey,
err: auth.ErrExpiry,
},
{
desc: "parse expired API key",
key: apiKey,
token: apiToken,
err: auth.ErrExpiry,
desc: "parse expired API key",
key: apiKey,
token: apiToken,
managerErr: errJWTExpiryKey,
err: auth.ErrExpiry,
},
{
desc: "parse token with invalid issuer",
key: auth.Key{},
token: inValidToken,
err: svcerr.ErrAuthentication,
desc: "parse token with invalid issuer",
key: auth.Key{},
token: signedInValidTkn,
managerRes: parsedInvalidTkn,
err: svcerr.ErrAuthentication,
},
{
desc: "parse token with invalid content",
key: auth.Key{},
token: newToken(issuerName, key()),
err: authjwt.ErrJSONHandle,
desc: "parse token with empty subject",
key: emptySubjectKey,
token: signedEmptySubjectTkn,
managerRes: parsedEmptySubjectTkn,
err: nil,
},
{
desc: "parse token with empty subject",
key: emptySubjectKey,
token: emptySubjectToken,
err: nil,
},
{
desc: "parse token with empty type",
key: emptyTypeKey,
token: emptyTypeToken,
err: svcerr.ErrAuthentication,
desc: "parse token with empty type",
key: emptyTypeKey,
token: emptyTypeToken,
managerRes: newToken(issuerName, emptyKey),
err: svcerr.ErrAuthentication,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
key, err := tokenizer.Parse(tc.token)
kmCall := keyManager.On("ParseJWT", tc.token).Return(tc.managerRes, tc.managerErr)
key, err := tokenizer.Parse(context.Background(), tc.token)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
}
kmCall.Unset()
})
}
}
func TestRetrieveJWKS(t *testing.T) {
tokenizer := authjwt.New(keyManager)
cases := []struct {
desc string
keys []auth.JWK
retrieveErr error
err error
}{
{
desc: "retrieve jwks with keys",
keys: []auth.JWK{newJWK(t), newJWK(t)},
},
{
desc: "retrieve empty jwks",
keys: []auth.JWK{},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
kmCall := keyManager.On("PublicJWKS", mock.Anything).Return(tc.keys, tc.retrieveErr)
jwks := tokenizer.RetrieveJWKS()
assert.Equal(t, tc.keys, jwks, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.keys, jwks))
kmCall.Unset()
})
}
}
@@ -236,3 +283,59 @@ func key() auth.Key {
ExpiresAt: exp,
}
}
func newToken(issuerName string, key auth.Key) jwt.Token {
builder := jwt.NewBuilder()
builder.
Issuer(issuerName).
IssuedAt(key.IssuedAt).
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(roleField, key.Role)
builder.Claim(VerifiedField, key.Verified)
if key.Subject != "" {
builder.Subject(key.Subject)
}
if key.ID != "" {
builder.JwtID(key.ID)
}
tkn, _ := builder.Build()
return tkn
}
func signToken(issuerName string, key auth.Key, parseToken bool) (string, jwt.Token, error) {
tkn := newToken(issuerName, key)
pKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
return "", nil, err
}
pubKey := &pKey.PublicKey
sTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.RS256, pKey))
if err != nil {
return "", nil, err
}
if !parseToken {
return string(sTkn), nil, nil
}
pTkn, err := jwt.Parse(
sTkn,
jwt.WithValidate(true),
jwt.WithKey(jwa.RS256, pubKey),
)
if err != nil {
return "", nil, err
}
return string(sTkn), pTkn, nil
}
func newJWK(t *testing.T) auth.JWK {
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
return auth.NewJWK(jwkKey)
}
+18 -23
View File
@@ -10,7 +10,6 @@ import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)
@@ -34,27 +33,23 @@ var (
)
const (
issuerName = "supermq.auth"
tokenType = "type"
userField = "user"
RoleField = "role"
VerifiedField = "verified"
oauthProviderField = "oauth_provider"
oauthAccessTokenField = "access_token"
oauthRefreshTokenField = "refresh_token"
patPrefix = "pat"
issuerName = "supermq.auth"
tokenType = "type"
RoleField = "role"
VerifiedField = "verified"
patPrefix = "pat"
)
type tokenizer struct {
secret []byte
keyManager auth.KeyManager
}
var _ auth.Tokenizer = (*tokenizer)(nil)
// NewRepository instantiates an implementation of Token repository.
func New(secret []byte) auth.Tokenizer {
// New instantiates an implementation of Tokenizer service.
func New(keyManager auth.KeyManager) auth.Tokenizer {
return &tokenizer{
secret: secret,
keyManager: keyManager,
}
}
@@ -77,14 +72,14 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
signedTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, tok.secret))
signedTkn, err := tok.keyManager.SignJWT(tkn)
if err != nil {
return "", errors.Wrap(ErrSignJWT, err)
}
return string(signedTkn), nil
}
func (tok *tokenizer) Parse(token string) (auth.Key, error) {
func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
if len(token) >= 3 && token[:3] == patPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
@@ -94,7 +89,7 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key, err := toKey(tkn)
key, err := ToKey(tkn)
if err != nil {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
@@ -103,11 +98,7 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) {
}
func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
tkn, err := jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKey(jwa.HS512, tok.secret),
)
tkn, err := tok.keyManager.ParseJWT(token)
if err != nil {
if errors.Contains(err, errJWTExpiryKey) {
return nil, auth.ErrExpiry
@@ -128,7 +119,11 @@ func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
return tkn, nil
}
func toKey(tkn jwt.Token) (auth.Key, error) {
func (tok *tokenizer) RetrieveJWKS() []auth.JWK {
return tok.keyManager.PublicJWKS()
}
func ToKey(tkn jwt.Token) (auth.Key, error) {
data, err := json.Marshal(tkn.PrivateClaims())
if err != nil {
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
+51
View File
@@ -0,0 +1,51 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"errors"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
)
// JWK represents a JSON Web Key.
type JWK struct {
key jwk.Key
}
// NewJWK creates a new JWK from a jwk.Key.
func NewJWK(key jwk.Key) JWK {
return JWK{key: key}
}
// Key returns the underlying jwk.Key.
func (j JWK) Key() jwk.Key {
return j.key
}
// KeyManager represents a manager for JWT keys.
type KeyManager interface {
SignJWT(token jwt.Token) ([]byte, error)
ParseJWT(token string) (jwt.Token, error)
PublicJWKS() []JWK
}
func IsSymmetricAlgorithm(alg string) (bool, error) {
switch alg {
case "HS256", "HS384", "HS512":
return true, nil
case "EdDSA":
return false, nil
default:
return false, ErrUnsupportedKeyAlgorithm
}
}
+132
View File
@@ -0,0 +1,132 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package asymmetric
import (
"crypto/ed25519"
"crypto/x509"
"encoding/pem"
"errors"
"os"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
var (
errLoadingPrivateKey = errors.New("failed to load private key")
errInvalidKeySize = errors.New("invalid ED25519 key size")
errParsingPrivateKey = errors.New("failed to parse private key")
errInvalidKeyType = errors.New("private key is not ED25519")
errGeneratingKID = errors.New("failed to generate key ID")
)
type manager struct {
privateKey jwk.Key
publicKey jwk.Key
kid string
}
var _ auth.KeyManager = (*manager)(nil)
// NewKeyManager creates a new asymmetric key manager that loads the private key from a file.
func NewKeyManager(privateKeyPath string, idProvider supermq.IDProvider) (auth.KeyManager, error) {
kid, err := idProvider.ID()
if err != nil {
return nil, errors.Join(errGeneratingKID, err)
}
privateJwk, publicJwk, err := loadKeyPair(privateKeyPath, kid)
if err != nil {
return nil, err
}
return &manager{
privateKey: privateJwk,
publicKey: publicJwk,
kid: kid,
}, nil
}
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, jwt.WithKey(jwa.EdDSA, km.privateKey))
}
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
set := jwk.NewSet()
if err := set.AddKey(km.publicKey); err != nil {
return nil, err
}
tkn, err := jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
)
if err != nil {
return nil, err
}
return tkn, nil
}
func (km *manager) PublicJWKS() []auth.JWK {
return []auth.JWK{auth.NewJWK(km.publicKey)}
}
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return nil, nil, errors.Join(errLoadingPrivateKey, err)
}
var privateKey ed25519.PrivateKey
block, _ := pem.Decode(privateKeyBytes)
switch {
case block != nil:
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, nil, errors.Join(errParsingPrivateKey, err)
}
var ok bool
privateKey, ok = parsedKey.(ed25519.PrivateKey)
if !ok {
return nil, nil, errInvalidKeyType
}
default:
if len(privateKeyBytes) != ed25519.PrivateKeySize {
return nil, nil, errInvalidKeySize
}
privateKey = ed25519.PrivateKey(privateKeyBytes)
}
publicKey := privateKey.Public().(ed25519.PublicKey)
privateJwk, err := jwk.FromRaw(privateKey)
if err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
publicJwk, err := jwk.FromRaw(publicKey)
if err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
return nil, nil, err
}
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
return nil, nil, err
}
return privateJwk, publicJwk, nil
}
+47
View File
@@ -0,0 +1,47 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package symmetric
import (
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)
type manager struct {
algorithm jwa.KeyAlgorithm
secret []byte
}
var _ auth.KeyManager = (*manager)(nil)
func NewKeyManager(algorithm string, secret []byte) (auth.KeyManager, error) {
alg := jwa.KeyAlgorithmFrom(algorithm)
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
return nil, auth.ErrUnsupportedKeyAlgorithm
}
if len(secret) == 0 {
return nil, auth.ErrInvalidSymmetricKey
}
return &manager{
secret: secret,
algorithm: alg,
}, nil
}
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
return jwt.Sign(token, jwt.WithKey(km.algorithm, km.secret))
}
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
return jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKey(km.algorithm, km.secret),
)
}
func (km *manager) PublicJWKS() []auth.JWK {
return nil
}
+10
View File
@@ -100,6 +100,16 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id aut
return lm.svc.Identify(ctx, token)
}
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.JWK) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
}
lm.logger.Info("Retrieve JWKS completed successfully", args...)
}(time.Now())
return lm.svc.RetrieveJWKS()
}
func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy) (err error) {
defer func(begin time.Time) {
args := []any{
+8
View File
@@ -67,6 +67,14 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.K
return ms.svc.Identify(ctx, token)
}
func (ms *metricsMiddleware) RetrieveJWKS() []auth.JWK {
defer func(begin time.Time) {
ms.counter.With("method", "retrieve_jwks").Add(1)
ms.latency.With("method", "retrieve_jwks").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RetrieveJWKS()
}
func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy) error {
defer func(begin time.Time) {
ms.counter.With("method", "authorize").Add(1)
+4
View File
@@ -61,6 +61,10 @@ func (tm *tracingMiddleware) Identify(ctx context.Context, token string) (auth.K
return tm.svc.Identify(ctx, token)
}
func (tm *tracingMiddleware) RetrieveJWKS() []auth.JWK {
return tm.svc.RetrieveJWKS()
}
func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy) error {
ctx, span := tm.tracer.Start(ctx, "authorize", trace.WithAttributes(
attribute.String("subject", pr.Subject),
+212
View File
@@ -0,0 +1,212 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"github.com/absmach/supermq/auth"
"github.com/lestrrat-go/jwx/v2/jwt"
mock "github.com/stretchr/testify/mock"
)
// NewKeyManager creates a new instance of KeyManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewKeyManager(t interface {
mock.TestingT
Cleanup(func())
}) *KeyManager {
mock := &KeyManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// KeyManager is an autogenerated mock type for the KeyManager type
type KeyManager struct {
mock.Mock
}
type KeyManager_Expecter struct {
mock *mock.Mock
}
func (_m *KeyManager) EXPECT() *KeyManager_Expecter {
return &KeyManager_Expecter{mock: &_m.Mock}
}
// ParseJWT provides a mock function for the type KeyManager
func (_mock *KeyManager) ParseJWT(token string) (jwt.Token, error) {
ret := _mock.Called(token)
if len(ret) == 0 {
panic("no return value specified for ParseJWT")
}
var r0 jwt.Token
var r1 error
if returnFunc, ok := ret.Get(0).(func(string) (jwt.Token, error)); ok {
return returnFunc(token)
}
if returnFunc, ok := ret.Get(0).(func(string) jwt.Token); ok {
r0 = returnFunc(token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(jwt.Token)
}
}
if returnFunc, ok := ret.Get(1).(func(string) error); ok {
r1 = returnFunc(token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// KeyManager_ParseJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJWT'
type KeyManager_ParseJWT_Call struct {
*mock.Call
}
// ParseJWT is a helper method to define mock.On call
// - token string
func (_e *KeyManager_Expecter) ParseJWT(token interface{}) *KeyManager_ParseJWT_Call {
return &KeyManager_ParseJWT_Call{Call: _e.mock.On("ParseJWT", token)}
}
func (_c *KeyManager_ParseJWT_Call) Run(run func(token string)) *KeyManager_ParseJWT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 string
if args[0] != nil {
arg0 = args[0].(string)
}
run(
arg0,
)
})
return _c
}
func (_c *KeyManager_ParseJWT_Call) Return(token1 jwt.Token, err error) *KeyManager_ParseJWT_Call {
_c.Call.Return(token1, err)
return _c
}
func (_c *KeyManager_ParseJWT_Call) RunAndReturn(run func(token string) (jwt.Token, error)) *KeyManager_ParseJWT_Call {
_c.Call.Return(run)
return _c
}
// PublicJWKS provides a mock function for the type KeyManager
func (_mock *KeyManager) PublicJWKS() []auth.JWK {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for PublicJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
}
}
return r0
}
// KeyManager_PublicJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublicJWKS'
type KeyManager_PublicJWKS_Call struct {
*mock.Call
}
// PublicJWKS is a helper method to define mock.On call
func (_e *KeyManager_Expecter) PublicJWKS() *KeyManager_PublicJWKS_Call {
return &KeyManager_PublicJWKS_Call{Call: _e.mock.On("PublicJWKS")}
}
func (_c *KeyManager_PublicJWKS_Call) Run(run func()) *KeyManager_PublicJWKS_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *KeyManager_PublicJWKS_Call) Return(jWKs []auth.JWK) *KeyManager_PublicJWKS_Call {
_c.Call.Return(jWKs)
return _c
}
func (_c *KeyManager_PublicJWKS_Call) RunAndReturn(run func() []auth.JWK) *KeyManager_PublicJWKS_Call {
_c.Call.Return(run)
return _c
}
// SignJWT provides a mock function for the type KeyManager
func (_mock *KeyManager) SignJWT(token jwt.Token) ([]byte, error) {
ret := _mock.Called(token)
if len(ret) == 0 {
panic("no return value specified for SignJWT")
}
var r0 []byte
var r1 error
if returnFunc, ok := ret.Get(0).(func(jwt.Token) ([]byte, error)); ok {
return returnFunc(token)
}
if returnFunc, ok := ret.Get(0).(func(jwt.Token) []byte); ok {
r0 = returnFunc(token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if returnFunc, ok := ret.Get(1).(func(jwt.Token) error); ok {
r1 = returnFunc(token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// KeyManager_SignJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignJWT'
type KeyManager_SignJWT_Call struct {
*mock.Call
}
// SignJWT is a helper method to define mock.On call
// - token jwt.Token
func (_e *KeyManager_Expecter) SignJWT(token interface{}) *KeyManager_SignJWT_Call {
return &KeyManager_SignJWT_Call{Call: _e.mock.On("SignJWT", token)}
}
func (_c *KeyManager_SignJWT_Call) Run(run func(token jwt.Token)) *KeyManager_SignJWT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 jwt.Token
if args[0] != nil {
arg0 = args[0].(jwt.Token)
}
run(
arg0,
)
})
return _c
}
func (_c *KeyManager_SignJWT_Call) Return(bytes []byte, err error) *KeyManager_SignJWT_Call {
_c.Call.Return(bytes, err)
return _c
}
func (_c *KeyManager_SignJWT_Call) RunAndReturn(run func(token jwt.Token) ([]byte, error)) *KeyManager_SignJWT_Call {
_c.Call.Return(run)
return _c
}
+46
View File
@@ -1028,6 +1028,52 @@ func (_c *Service_ResetPATSecret_Call) RunAndReturn(run func(ctx context.Context
return _c
}
// RetrieveJWKS provides a mock function for the type Service
func (_mock *Service) RetrieveJWKS() []auth.JWK {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for RetrieveJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
}
}
return r0
}
// Service_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
type Service_RetrieveJWKS_Call struct {
*mock.Call
}
// RetrieveJWKS is a helper method to define mock.On call
func (_e *Service_Expecter) RetrieveJWKS() *Service_RetrieveJWKS_Call {
return &Service_RetrieveJWKS_Call{Call: _e.mock.On("RetrieveJWKS")}
}
func (_c *Service_RetrieveJWKS_Call) Run(run func()) *Service_RetrieveJWKS_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Service_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Service_RetrieveJWKS_Call {
_c.Call.Return(jWKs)
return _c
}
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Service_RetrieveJWKS_Call {
_c.Call.Return(run)
return _c
}
// RetrieveKey provides a mock function for the type Service
func (_mock *Service) RetrieveKey(ctx context.Context, token string, id string) (auth.Key, error) {
ret := _mock.Called(ctx, token, id)
+215
View File
@@ -0,0 +1,215 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
"github.com/absmach/supermq/auth"
mock "github.com/stretchr/testify/mock"
)
// NewTokenizer creates a new instance of Tokenizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewTokenizer(t interface {
mock.TestingT
Cleanup(func())
}) *Tokenizer {
mock := &Tokenizer{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Tokenizer is an autogenerated mock type for the Tokenizer type
type Tokenizer struct {
mock.Mock
}
type Tokenizer_Expecter struct {
mock *mock.Mock
}
func (_m *Tokenizer) EXPECT() *Tokenizer_Expecter {
return &Tokenizer_Expecter{mock: &_m.Mock}
}
// Issue provides a mock function for the type Tokenizer
func (_mock *Tokenizer) Issue(key auth.Key) (string, error) {
ret := _mock.Called(key)
if len(ret) == 0 {
panic("no return value specified for Issue")
}
var r0 string
var r1 error
if returnFunc, ok := ret.Get(0).(func(auth.Key) (string, error)); ok {
return returnFunc(key)
}
if returnFunc, ok := ret.Get(0).(func(auth.Key) string); ok {
r0 = returnFunc(key)
} else {
r0 = ret.Get(0).(string)
}
if returnFunc, ok := ret.Get(1).(func(auth.Key) error); ok {
r1 = returnFunc(key)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Tokenizer_Issue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Issue'
type Tokenizer_Issue_Call struct {
*mock.Call
}
// Issue is a helper method to define mock.On call
// - key auth.Key
func (_e *Tokenizer_Expecter) Issue(key interface{}) *Tokenizer_Issue_Call {
return &Tokenizer_Issue_Call{Call: _e.mock.On("Issue", key)}
}
func (_c *Tokenizer_Issue_Call) Run(run func(key auth.Key)) *Tokenizer_Issue_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 auth.Key
if args[0] != nil {
arg0 = args[0].(auth.Key)
}
run(
arg0,
)
})
return _c
}
func (_c *Tokenizer_Issue_Call) Return(token string, err error) *Tokenizer_Issue_Call {
_c.Call.Return(token, err)
return _c
}
func (_c *Tokenizer_Issue_Call) RunAndReturn(run func(key auth.Key) (string, error)) *Tokenizer_Issue_Call {
_c.Call.Return(run)
return _c
}
// Parse provides a mock function for the type Tokenizer
func (_mock *Tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
ret := _mock.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for Parse")
}
var r0 auth.Key
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (auth.Key, error)); ok {
return returnFunc(ctx, token)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) auth.Key); ok {
r0 = returnFunc(ctx, token)
} else {
r0 = ret.Get(0).(auth.Key)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Tokenizer_Parse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Parse'
type Tokenizer_Parse_Call struct {
*mock.Call
}
// Parse is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *Tokenizer_Expecter) Parse(ctx interface{}, token interface{}) *Tokenizer_Parse_Call {
return &Tokenizer_Parse_Call{Call: _e.mock.On("Parse", ctx, token)}
}
func (_c *Tokenizer_Parse_Call) Run(run func(ctx context.Context, token string)) *Tokenizer_Parse_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Tokenizer_Parse_Call) Return(key auth.Key, err error) *Tokenizer_Parse_Call {
_c.Call.Return(key, err)
return _c
}
func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token string) (auth.Key, error)) *Tokenizer_Parse_Call {
_c.Call.Return(run)
return _c
}
// RetrieveJWKS provides a mock function for the type Tokenizer
func (_mock *Tokenizer) RetrieveJWKS() []auth.JWK {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for RetrieveJWKS")
}
var r0 []auth.JWK
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.JWK)
}
}
return r0
}
// Tokenizer_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
type Tokenizer_RetrieveJWKS_Call struct {
*mock.Call
}
// RetrieveJWKS is a helper method to define mock.On call
func (_e *Tokenizer_Expecter) RetrieveJWKS() *Tokenizer_RetrieveJWKS_Call {
return &Tokenizer_RetrieveJWKS_Call{Call: _e.mock.On("RetrieveJWKS")}
}
func (_c *Tokenizer_RetrieveJWKS_Call) Run(run func()) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Tokenizer_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Return(jWKs)
return _c
}
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
_c.Call.Return(run)
return _c
}
+14 -7
View File
@@ -83,6 +83,9 @@ type Authn interface {
// is returned. If token is invalid, or invocation failed for some
// other reason, non-nil error value is returned in response.
Identify(ctx context.Context, token string) (Key, error)
// RetrieveJWKS retrieves a JWKs to validate issued tokens.
RetrieveJWKS() []JWK
}
// Service specifies an API that must be fulfilled by the domain service
@@ -145,7 +148,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
}
func (svc service) Revoke(ctx context.Context, token, id string) error {
issuerID, _, err := svc.authenticate(token)
issuerID, _, err := svc.authenticate(ctx, token)
if err != nil {
return errors.Wrap(errRevoke, err)
}
@@ -156,7 +159,7 @@ func (svc service) Revoke(ctx context.Context, token, id string) error {
}
func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) {
issuerID, _, err := svc.authenticate(token)
issuerID, _, err := svc.authenticate(ctx, token)
if err != nil {
return Key{}, errors.Wrap(errRetrieve, err)
}
@@ -169,7 +172,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro
}
func (svc service) Identify(ctx context.Context, token string) (Key, error) {
key, err := svc.tokenizer.Parse(token)
key, err := svc.tokenizer.Parse(ctx, token)
if errors.Contains(err, ErrExpiry) {
err = svc.keys.Remove(ctx, key.Issuer, key.ID)
return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err))
@@ -198,6 +201,10 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) {
}
}
func (svc service) RetrieveJWKS() []JWK {
return svc.tokenizer.RetrieveJWKS()
}
func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
if pr.PatID != "" && pr.TokenType == PersonalAccessTokenType {
if err := svc.AuthorizePAT(ctx, pr.UserID, pr.PatID, EntityType(pr.EntityType), pr.OptionalDomainID, Operation(pr.Operation), pr.EntityID); err != nil {
@@ -325,7 +332,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) {
}
func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) {
k, err := svc.tokenizer.Parse(token)
k, err := svc.tokenizer.Parse(ctx, token)
if err != nil {
return Token{}, errors.Wrap(errRetrieve, err)
}
@@ -402,7 +409,7 @@ func (svc service) getUserRole(ctx context.Context, userID string) (role Role) {
}
func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) {
id, sub, err := svc.authenticate(token)
id, sub, err := svc.authenticate(ctx, token)
if err != nil {
return Token{}, errors.Wrap(errIssueUser, err)
}
@@ -433,8 +440,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e
return Token{AccessToken: tkn}, nil
}
func (svc service) authenticate(token string) (string, string, error) {
key, err := svc.tokenizer.Parse(token)
func (svc service) authenticate(ctx context.Context, token string) (string, string, error) {
key, err := svc.tokenizer.Parse(ctx, token)
if err != nil {
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
}
+426 -262
View File
File diff suppressed because it is too large Load Diff
+8 -1
View File
@@ -3,11 +3,18 @@
package auth
import (
"context"
)
// Tokenizer specifies API for encoding and decoding between string and Key.
type Tokenizer interface {
// Issue converts API Key to its string representation.
Issue(key Key) (token string, err error)
// Parse extracts API Key data from string token.
Parse(token string) (key Key, err error)
Parse(ctx context.Context, token string) (key Key, err error)
// RetrieveJWKS returns the JSON Web Key Set.
RetrieveJWKS() []JWK
}
+54 -22
View File
@@ -23,6 +23,8 @@ import (
"github.com/absmach/supermq/auth/cache"
"github.com/absmach/supermq/auth/hasher"
"github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/auth/keymanager/asymmetric"
"github.com/absmach/supermq/auth/keymanager/symmetric"
"github.com/absmach/supermq/auth/middleware"
apostgres "github.com/absmach/supermq/auth/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
@@ -59,22 +61,26 @@ const (
)
type config struct {
LogLevel string `env:"SMQ_AUTH_LOG_LEVEL" envDefault:"info"`
SecretKey string `env:"SMQ_AUTH_SECRET_KEY" envDefault:"secret"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_AUTH_ADAPTER_INSTANCE_ID" envDefault:""`
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"`
CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"`
LogLevel string `env:"SMQ_AUTH_LOG_LEVEL" envDefault:"info"`
SecretKey string `env:"SMQ_AUTH_SECRET_KEY" envDefault:"secret"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_AUTH_ADAPTER_INSTANCE_ID" envDefault:""`
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"`
PrivateKeyPath string `env:"SMQ_AUTH_KEYS_PRIVATE_KEY_PATH" envDefault:"./ssl/keys/private.pem"`
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"`
CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"`
JWKSCacheMaxAge int `env:"SMQ_AUTH_JWKS_CACHE_MAX_AGE" envDefault:"900"`
JWKSCacheStaleWhileRevalidate int `env:"SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE" envDefault:"60"`
}
func main() {
@@ -144,7 +150,34 @@ func main() {
return
}
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.KeyAlgorithm)
if err != nil {
logger.Error(fmt.Sprintf("failed to determine key algorithm type: %s", err.Error()))
exitCode = 1
return
}
idProvider := uuid.New()
var keyManager auth.KeyManager
switch {
case isSymmetric:
keyManager, err = symmetric.NewKeyManager(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
if err != nil {
logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error()))
exitCode = 1
return
}
default:
keyManager, err = asymmetric.NewKeyManager(cfg.PrivateKeyPath, idProvider)
if err != nil {
logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error()))
exitCode = 1
return
}
}
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, keyManager, idProvider)
if err != nil {
logger.Error(fmt.Sprintf("failed to create service : %s\n", err.Error()))
exitCode = 1
@@ -180,7 +213,7 @@ func main() {
exitCode = 1
return
}
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, cfg.InstanceID), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, cfg.InstanceID, cfg.JWKSCacheMaxAge, cfg.JWKSCacheStaleWhileRevalidate), logger)
g.Go(func() error {
return hs.Start()
@@ -225,21 +258,20 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
return nil
}
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration) (auth.Service, error) {
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, keyManager auth.KeyManager, idProvider supermq.IDProvider) (auth.Service, error) {
cache := cache.NewPatsCache(cacheClient, keyDuration)
database := pgclient.NewDatabase(db, dbConfig, tracer)
keysRepo := apostgres.New(database)
patsRepo := apostgres.NewPatRepo(database, cache)
hasher := hasher.New()
idProvider := uuid.New()
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
pService := spicedb.NewPolicyService(spicedbClient, logger)
t := jwt.New([]byte(cfg.SecretKey))
tokenizer := jwt.New(keyManager)
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc = middleware.NewLogging(svc, logger)
counter, latency := prometheus.MakeMetrics("auth", "api")
svc = middleware.NewMetrics(svc, counter, latency)
+29 -4
View File
@@ -18,6 +18,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/channels"
grpcapi "github.com/absmach/supermq/channels/api/grpc"
httpapi "github.com/absmach/supermq/channels/api/http"
@@ -35,6 +36,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
"github.com/absmach/supermq/pkg/callout"
@@ -98,6 +100,8 @@ type config struct {
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
}
@@ -176,14 +180,35 @@ func main() {
exitCode = 1
return
}
authn, authnClient, err := authsvcAuthn.NewAuthentication(ctx, grpcCfg)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, grpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, grpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
+30 -5
View File
@@ -18,6 +18,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/clients"
grpcapi "github.com/absmach/supermq/clients/api/grpc"
httpapi "github.com/absmach/supermq/clients/api/http"
@@ -34,6 +35,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
"github.com/absmach/supermq/pkg/callout"
@@ -99,6 +101,8 @@ type config struct {
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
}
@@ -186,14 +190,35 @@ func main() {
exitCode = 1
return
}
authn, authnClient, err := authsvcAuthn.NewAuthentication(ctx, grpcCfg)
alg, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !alg:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, grpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, grpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
@@ -224,7 +249,7 @@ func main() {
return
}
defer authzClient.Close()
logger.Info("AuthZ successfully connected to auth gRPC server " + authnClient.Secure())
logger.Info("AuthZ successfully connected to auth gRPC server " + authzClient.Secure())
chgrpccfg := grpcclient.Config{}
if err := env.ParseWithOptions(&chgrpccfg, env.Options{Prefix: envPrefixChannels}); err != nil {
+28 -4
View File
@@ -15,6 +15,7 @@ import (
chclient "github.com/absmach/callhome/pkg/client"
"github.com/absmach/supermq"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
domainsSvc "github.com/absmach/supermq/domains"
domainsgrpcapi "github.com/absmach/supermq/domains/api/grpc"
@@ -28,6 +29,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
"github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
"github.com/absmach/supermq/pkg/callout"
@@ -83,6 +85,8 @@ type config struct {
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
}
@@ -153,14 +157,34 @@ func main() {
return
}
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, clientConfig)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(fmt.Sprintf("authn failed to connect to auth gRPC server : %s", err.Error()))
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, clientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, clientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
database := postgres.NewDatabase(db, dbConfig, tracer)
+28 -4
View File
@@ -17,6 +17,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
dpostgres "github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/groups"
@@ -30,6 +31,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
"github.com/absmach/supermq/pkg/callout"
@@ -89,6 +91,8 @@ type config struct {
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
}
@@ -157,14 +161,34 @@ func main() {
return
}
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientConfig)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error("failed to create authn " + err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
+36 -12
View File
@@ -22,11 +22,13 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
"github.com/absmach/supermq/auth"
adapter "github.com/absmach/supermq/http"
httpapi "github.com/absmach/supermq/http/api"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authn/authsvc"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
@@ -60,13 +62,15 @@ const (
)
type config struct {
LogLevel string `env:"SMQ_HTTP_ADAPTER_LOG_LEVEL" envDefault:"info"`
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
LogLevel string `env:"SMQ_HTTP_ADAPTER_LOG_LEVEL" envDefault:"info"`
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
}
func main() {
@@ -163,14 +167,34 @@ func main() {
return
}
authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
+34 -10
View File
@@ -14,6 +14,7 @@ import (
chclient "github.com/absmach/callhome/pkg/client"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/journal"
httpapi "github.com/absmach/supermq/journal/api"
"github.com/absmach/supermq/journal/events"
@@ -22,6 +23,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
@@ -51,12 +53,14 @@ const (
)
type config struct {
LogLevel string `env:"SMQ_JOURNAL_LOG_LEVEL" envDefault:"info"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_JOURNAL_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
LogLevel string `env:"SMQ_JOURNAL_LOG_LEVEL" envDefault:"info"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_JOURNAL_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
}
func main() {
@@ -105,14 +109,34 @@ func main() {
return
}
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientCfg)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
+28 -6
View File
@@ -19,10 +19,12 @@ import (
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
grpcUsersV1 "github.com/absmach/supermq/api/grpc/users/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/internal/email"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
@@ -97,6 +99,8 @@ type config struct {
PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"`
VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"`
VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
PassRegex *regexp.Regexp
}
@@ -194,17 +198,35 @@ func main() {
defer tokenHandler.Close()
logger.Info("Token service client successfully connected to auth gRPC server " + tokenHandler.Secure())
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientConfig)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error("failed to create authn " + err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
var authn smqauthn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
+36 -12
View File
@@ -19,9 +19,11 @@ import (
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/auth"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authn/authsvc"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
@@ -56,13 +58,15 @@ const (
)
type config struct {
LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"`
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"`
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
}
func main() {
@@ -158,14 +162,34 @@ func main() {
return
}
authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg)
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(err.Error())
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure())
var authn authn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
+5 -1
View File
@@ -98,13 +98,17 @@ SMQ_AUTH_DB_SSL_MODE=disable
SMQ_AUTH_DB_SSL_CERT=
SMQ_AUTH_DB_SSL_KEY=
SMQ_AUTH_DB_SSL_ROOT_CERT=
SMQ_AUTH_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH
SMQ_AUTH_ACCESS_TOKEN_DURATION="1h"
SMQ_AUTH_REFRESH_TOKEN_DURATION="24h"
SMQ_AUTH_KEYS_ALGORITHM="EdDSA"
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH="./ssl/keys/private.key"
SMQ_AUTH_INVITATION_DURATION="168h"
SMQ_AUTH_ADAPTER_INSTANCE_ID=
SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0
SMQ_AUTH_CACHE_KEY_DURATION=10m
SMQ_AUTH_JWKS_URL=http://${SMQ_AUTH_HTTP_HOST}:${SMQ_AUTH_HTTP_PORT}/keys/.well-known/jwks.json
SMQ_AUTH_JWKS_CACHE_MAX_AGE=900
SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE=60
#### Client Callout
SMQ_CLIENTS_CALLOUT_URLS=""
+19 -3
View File
@@ -23,6 +23,7 @@ volumes:
supermq-domains-db-volume:
supermq-domains-redis-volume:
supermq-auth-redis-volume:
supermq-auth-keys-volume:
services:
spicedb:
@@ -110,16 +111,17 @@ services:
SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY}
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
SMQ_AUTH_INVITATION_DURATION: ${SMQ_AUTH_INVITATION_DURATION}
SMQ_AUTH_SECRET_KEY: ${SMQ_AUTH_SECRET_KEY}
SMQ_AUTH_HTTP_HOST: ${SMQ_AUTH_HTTP_HOST}
SMQ_AUTH_HTTP_PORT: ${SMQ_AUTH_HTTP_PORT}
SMQ_AUTH_HTTP_SERVER_CERT: ${SMQ_AUTH_HTTP_SERVER_CERT}
SMQ_AUTH_HTTP_SERVER_KEY: ${SMQ_AUTH_HTTP_SERVER_KEY}
SMQ_AUTH_GRPC_HOST: ${SMQ_AUTH_GRPC_HOST}
SMQ_AUTH_GRPC_PORT: ${SMQ_AUTH_GRPC_PORT}
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:+/keys/private.key}
## Compose supports parameter expansion in environment,
## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty
## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default
@@ -150,6 +152,13 @@ services:
volumes:
- ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE}
- supermq-pat-db-volume:/supermq-data
- supermq-auth-keys-volume:/keys
# Auth private key file
- type: bind
source: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:-ssl/certs/dummy/private_key}
target: /keys/private.key
bind:
create_host_path: true
# Auth gRPC mTLS server certificates
- type: bind
source: ${SMQ_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert}
@@ -258,6 +267,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_GROUPS_GRPC_URL: ${SMQ_GROUPS_GRPC_URL}
SMQ_GROUPS_GRPC_TIMEOUT: ${SMQ_GROUPS_GRPC_TIMEOUT}
SMQ_GROUPS_GRPC_CLIENT_CERT: ${SMQ_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt}
@@ -490,6 +500,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_CHANNELS_URL: ${SMQ_CHANNELS_URL}
SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL}
SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT}
@@ -683,6 +694,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL}
SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT}
SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt}
@@ -883,6 +895,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL}
SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT}
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
@@ -1069,6 +1082,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY}
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
@@ -1323,6 +1337,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL}
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
@@ -1553,6 +1568,7 @@ services:
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL}
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
+3
View File
@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIPB+6hA+8rK067SdVlkWzgtxEUNysMhFFFzmGsKB1BAl
-----END PRIVATE KEY-----
+2
View File
@@ -19,6 +19,8 @@ const (
PersonalAccessToken
)
const PatPrefix = "pat_"
func (t TokenType) String() string {
switch t {
case AccessToken:
+1 -3
View File
@@ -15,8 +15,6 @@ import (
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
const patPrefix = "pat_"
type authentication struct {
authSvcClient grpcAuthV1.AuthServiceClient
}
@@ -46,7 +44,7 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
if strings.HasPrefix(token, patPrefix) {
if strings.HasPrefix(token, authn.PatPrefix) {
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
}
+174
View File
@@ -0,0 +1,174 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package jwks
import (
"context"
"io"
"net/http"
"strings"
"sync"
"time"
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
smqauth "github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/api/grpc/auth"
smqjwt "github.com/absmach/supermq/auth/jwt"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/grpcclient"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
const (
issuerName = "supermq.auth"
cacheDuration = 5 * time.Minute
)
var (
// errJWTExpiryKey is used to check if the token is expired.
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
// errFetchJWKS indicates an error fetching JWKS from URL.
errFetchJWKS = errors.New("failed to fetch jwks")
// errInvalidIssuer indicates an invalid issuer value.
errInvalidIssuer = errors.New("invalid token issuer value")
// ErrValidateJWTToken indicates a failure to validate JWT token.
errValidateJWTToken = errors.New("failed to validate jwt token")
jwksCache = struct {
sync.RWMutex
jwks jwk.Set
cachedAt time.Time
}{}
)
var _ authn.Authentication = (*authentication)(nil)
type authentication struct {
jwksURL string
authSvcClient grpcAuthV1.AuthServiceClient
}
func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
client, err := grpcclient.NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "auth",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, grpcclient.ErrSvcNotServing
}
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
return authentication{
jwksURL: jwksURL,
authSvcClient: authSvcClient,
}, client, nil
}
func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
if strings.HasPrefix(token, authn.PatPrefix) {
res, err := a.authSvcClient.Authenticate(ctx, &grpcAuthV1.AuthNReq{Token: token})
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
}
jwks, err := a.fetchJWKS(ctx)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
tkn, err := validateToken(token, jwks)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
key, err := smqjwt.ToKey(tkn)
if err != nil {
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return authn.Session{
Type: authn.AccessToken,
UserID: key.Subject,
Role: authn.Role(key.Role),
Verified: key.Verified,
}, nil
}
func (a authentication) fetchJWKS(ctx context.Context) (jwk.Set, error) {
jwksCache.RLock()
if time.Since(jwksCache.cachedAt) < cacheDuration && jwksCache.jwks.Len() > 0 {
cached := jwksCache.jwks
jwksCache.RUnlock()
return cached, nil
}
jwksCache.RUnlock()
req, err := http.NewRequestWithContext(ctx, "GET", a.jwksURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, errFetchJWKS
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
set, err := jwk.Parse(data)
if err != nil {
return nil, err
}
jwksCache.Lock()
jwksCache.jwks = set
jwksCache.cachedAt = time.Now()
jwksCache.Unlock()
return set, nil
}
func validateToken(token string, jwks jwk.Set) (jwt.Token, error) {
tkn, err := jwt.Parse(
[]byte(token),
jwt.WithValidate(true),
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true)),
)
if err != nil {
if errors.Contains(err, errJWTExpiryKey) {
return nil, smqauth.ErrExpiry
}
return nil, err
}
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
if t.Issuer() != issuerName {
return jwt.NewValidationError(errInvalidIssuer)
}
return nil
})
if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil {
return nil, errors.Wrap(errValidateJWTToken, err)
}
return tkn, nil
}
+2
View File
@@ -58,6 +58,8 @@ packages:
Cache:
Hasher:
KeyRepository:
KeyManager:
Tokenizer:
PATS:
PATSRepository:
Service: