mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-1672 - Add asymmetric key authentication (#3228)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
@@ -85,3 +85,15 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return revokeKeyRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveJWKSEndpoint(svc auth.Service, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
jwks := svc.RetrieveJWKS()
|
||||
|
||||
return retrieveJWKSRes{
|
||||
Keys: jwks,
|
||||
CacheMaxAge: jwksCacheMaxAge,
|
||||
CacheStaleWhileRevalidate: jwksCacheStaleWhileRevalidate,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
package keys_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -17,30 +18,26 @@ import (
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/auth"
|
||||
httpapi "github.com/absmach/supermq/auth/api/http"
|
||||
"github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/auth/mocks"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
policymocks "github.com/absmach/supermq/pkg/policies/mocks"
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
secret = "secret"
|
||||
contentType = "application/json"
|
||||
id = "123e4567-e89b-12d3-a456-000000000001"
|
||||
email = "user@example.com"
|
||||
loginDuration = 30 * time.Minute
|
||||
refreshDuration = 24 * time.Hour
|
||||
invalidDuration = 7 * 24 * time.Hour
|
||||
accessToken = "valid token"
|
||||
)
|
||||
|
||||
var (
|
||||
krepo *mocks.KeyRepository
|
||||
pEvaluator *policymocks.Evaluator
|
||||
)
|
||||
var Token = auth.Token{
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
|
||||
type issueRequest struct {
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
@@ -72,22 +69,11 @@ func (tr testRequest) make() (*http.Response, error) {
|
||||
return tr.client.Do(req)
|
||||
}
|
||||
|
||||
func newService() auth.Service {
|
||||
krepo = new(mocks.KeyRepository)
|
||||
pRepo := new(mocks.PATSRepository)
|
||||
cache := new(mocks.Cache)
|
||||
hash := new(mocks.Hasher)
|
||||
idProvider := uuid.NewMock()
|
||||
pService := new(policymocks.Service)
|
||||
pEvaluator = new(policymocks.Evaluator)
|
||||
t := jwt.New([]byte(secret))
|
||||
func newServer() (*httptest.Server, *mocks.Service) {
|
||||
svc := new(mocks.Service)
|
||||
mux := httpapi.MakeHandler(svc, smqlog.NewMock(), "", 900, 60)
|
||||
|
||||
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration)
|
||||
}
|
||||
|
||||
func newServer(svc auth.Service) *httptest.Server {
|
||||
mux := httpapi.MakeHandler(svc, smqlog.NewMock(), "")
|
||||
return httptest.NewServer(mux)
|
||||
return httptest.NewServer(mux), svc
|
||||
}
|
||||
|
||||
func toJSON(data any) string {
|
||||
@@ -99,13 +85,7 @@ func toJSON(data any) string {
|
||||
}
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
policyCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
ts, svc := newServer()
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
|
||||
@@ -119,6 +99,8 @@ func TestIssue(t *testing.T) {
|
||||
ct string
|
||||
token string
|
||||
status int
|
||||
svcRes auth.Token
|
||||
svcErr error
|
||||
}{
|
||||
{
|
||||
desc: "issue login key with empty token",
|
||||
@@ -131,28 +113,30 @@ func TestIssue(t *testing.T) {
|
||||
desc: "issue API key",
|
||||
req: toJSON(ak),
|
||||
ct: contentType,
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusCreated,
|
||||
svcRes: Token,
|
||||
},
|
||||
{
|
||||
desc: "issue recovery key",
|
||||
req: toJSON(rk),
|
||||
ct: contentType,
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusCreated,
|
||||
svcRes: Token,
|
||||
},
|
||||
{
|
||||
desc: "issue login key wrong content type",
|
||||
req: toJSON(lk),
|
||||
ct: "",
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
desc: "issue recovery key wrong content type",
|
||||
req: toJSON(rk),
|
||||
ct: "",
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
@@ -160,6 +144,7 @@ func TestIssue(t *testing.T) {
|
||||
req: toJSON(ak),
|
||||
ct: contentType,
|
||||
token: "wrong",
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
@@ -167,27 +152,28 @@ func TestIssue(t *testing.T) {
|
||||
req: toJSON(rk),
|
||||
ct: contentType,
|
||||
token: "",
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
desc: "issue key with invalid request",
|
||||
req: "{",
|
||||
ct: contentType,
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
desc: "issue key with invalid JSON",
|
||||
req: "{invalid}",
|
||||
ct: contentType,
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
desc: "issue key with invalid JSON content",
|
||||
req: `{"Type":{"key":"AccessToken"}}`,
|
||||
ct: contentType,
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
@@ -201,30 +187,16 @@ func TestIssue(t *testing.T) {
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.req),
|
||||
}
|
||||
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
svcCall := svc.On("Issue", mock.Anything, tc.token, mock.Anything).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieve(t *testing.T) {
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
|
||||
|
||||
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil)
|
||||
k, err := svc.Issue(context.Background(), token.AccessToken, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
ts, svc := newServer()
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
|
||||
@@ -234,12 +206,13 @@ func TestRetrieve(t *testing.T) {
|
||||
token string
|
||||
key auth.Key
|
||||
status int
|
||||
err error
|
||||
svcRes auth.Key
|
||||
svcErr error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve an existing key",
|
||||
id: k.AccessToken,
|
||||
token: token.AccessToken,
|
||||
id: id,
|
||||
token: accessToken,
|
||||
key: auth.Key{
|
||||
Subject: id,
|
||||
Type: auth.AccessKey,
|
||||
@@ -247,28 +220,28 @@ func TestRetrieve(t *testing.T) {
|
||||
ExpiresAt: time.Now().Add(refreshDuration),
|
||||
},
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
svcErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a non-existing key",
|
||||
id: "non-existing",
|
||||
token: token.AccessToken,
|
||||
token: accessToken,
|
||||
status: http.StatusNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a key with an invalid token",
|
||||
id: k.AccessToken,
|
||||
id: accessToken,
|
||||
token: "wrong",
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a key with an empty token",
|
||||
token: "",
|
||||
id: k.AccessToken,
|
||||
id: accessToken,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -279,30 +252,16 @@ func TestRetrieve(t *testing.T) {
|
||||
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
repoCall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err)
|
||||
svcCall := svc.On("RetrieveKey", mock.Anything, tc.token, tc.id).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
|
||||
|
||||
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil)
|
||||
k, err := svc.Issue(context.Background(), token.AccessToken, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
ts, svc := newServer()
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
|
||||
@@ -311,29 +270,33 @@ func TestRevoke(t *testing.T) {
|
||||
id string
|
||||
token string
|
||||
status int
|
||||
svcErr error
|
||||
}{
|
||||
{
|
||||
desc: "revoke an existing key",
|
||||
id: k.AccessToken,
|
||||
token: token.AccessToken,
|
||||
id: id,
|
||||
token: accessToken,
|
||||
status: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
desc: "revoke a non-existing key",
|
||||
id: "non-existing",
|
||||
token: token.AccessToken,
|
||||
status: http.StatusNoContent,
|
||||
token: accessToken,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "revoke key with invalid token",
|
||||
id: k.AccessToken,
|
||||
id: id,
|
||||
token: "wrong",
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
desc: "revoke key with empty token",
|
||||
id: k.AccessToken,
|
||||
id: id,
|
||||
token: "",
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
@@ -345,10 +308,58 @@ func TestRevoke(t *testing.T) {
|
||||
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
repoCall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
svcCall := svc.On("Revoke", mock.Anything, tc.token, tc.id).Return(tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
repoCall.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveJWKS(t *testing.T) {
|
||||
ts, svc := newServer()
|
||||
defer ts.Close()
|
||||
client := ts.Client()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
svcRes []auth.JWK
|
||||
status int
|
||||
}{
|
||||
{
|
||||
desc: "retrieve JWKS with keys",
|
||||
svcRes: []auth.JWK{newJWK(t), newJWK(t)},
|
||||
status: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "retrieve empty JWKS",
|
||||
svcRes: []auth.JWK{},
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
client: client,
|
||||
method: http.MethodGet,
|
||||
url: fmt.Sprintf("%s/keys/.well-known/jwks.json", ts.URL),
|
||||
}
|
||||
svcCall := svc.On("RetrieveJWKS").Return(tc.svcRes)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func newJWK(t *testing.T) auth.JWK {
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
|
||||
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
|
||||
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
|
||||
return auth.NewJWK(jwkKey)
|
||||
}
|
||||
|
||||
@@ -46,3 +46,5 @@ func (req keyReq) validate() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type jwksReq struct{}
|
||||
|
||||
@@ -4,16 +4,21 @@
|
||||
package keys
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
)
|
||||
|
||||
var (
|
||||
_ supermq.Response = (*issueKeyRes)(nil)
|
||||
_ supermq.Response = (*revokeKeyRes)(nil)
|
||||
_ supermq.Response = (*retrieveKeyRes)(nil)
|
||||
_ supermq.Response = (*retrieveJWKSRes)(nil)
|
||||
)
|
||||
|
||||
type issueKeyRes struct {
|
||||
@@ -69,3 +74,37 @@ func (res revokeKeyRes) Headers() map[string]string {
|
||||
func (res revokeKeyRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type retrieveJWKSRes struct {
|
||||
Keys []auth.JWK `json:"-"`
|
||||
CacheMaxAge int `json:"-"`
|
||||
CacheStaleWhileRevalidate int `json:"-"`
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) MarshalJSON() ([]byte, error) {
|
||||
set := jwk.NewSet()
|
||||
for _, k := range res.Keys {
|
||||
if err := set.AddKey(k.Key()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(set)
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) Headers() map[string]string {
|
||||
cacheControl := fmt.Sprintf("public, max-age=%d, stale-while-revalidate=%d", res.CacheMaxAge, res.CacheStaleWhileRevalidate)
|
||||
headers := map[string]string{
|
||||
"Cache-Control": cacheControl,
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (res retrieveJWKSRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
const contentType = "application/json"
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
|
||||
func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) *chi.Mux {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
@@ -46,6 +46,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
).ServeHTTP)
|
||||
|
||||
r.Get("/.well-known/jwks.json", kithttp.NewServer(
|
||||
retrieveJWKSEndpoint(svc, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate),
|
||||
decodeJWKSReq,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
).ServeHTTP)
|
||||
})
|
||||
return mux
|
||||
}
|
||||
@@ -70,3 +77,8 @@ func decodeKeyReq(_ context.Context, r *http.Request) (any, error) {
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeJWKSReq(_ context.Context, _ *http.Request) (any, error) {
|
||||
req := jwksReq{}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string) http.Handler {
|
||||
func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate int) http.Handler {
|
||||
mux := chi.NewRouter()
|
||||
|
||||
mux = keys.MakeHandler(svc, mux, logger)
|
||||
mux = keys.MakeHandler(svc, mux, logger, jwksCacheMaxAge, jwksCacheStaleWhileRevalidate)
|
||||
mux = pats.MakeHandler(svc, mux, logger)
|
||||
|
||||
mux.Get("/health", supermq.Health("auth", instanceID))
|
||||
|
||||
+188
-85
@@ -4,61 +4,59 @@
|
||||
package jwt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
authjwt "github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/auth/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenType = "type"
|
||||
roleField = "role"
|
||||
issuerName = "supermq.auth"
|
||||
secret = "test"
|
||||
tokenType = "type"
|
||||
roleField = "role"
|
||||
VerifiedField = "verified"
|
||||
issuerName = "supermq.auth"
|
||||
)
|
||||
|
||||
var reposecret = []byte("test")
|
||||
|
||||
func newToken(issuerName string, key auth.Key) string {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.
|
||||
Issuer(issuerName).
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, "r").
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(roleField, key.Role)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
if key.ID != "" {
|
||||
builder.JwtID(key.ID)
|
||||
}
|
||||
tkn, _ := builder.Build()
|
||||
tokn, _ := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, reposecret))
|
||||
return string(tokn)
|
||||
}
|
||||
var (
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
keyManager = new(mocks.KeyManager)
|
||||
)
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
tokenizer := authjwt.New([]byte(secret))
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
validKey := key()
|
||||
signedToken, _, err := signToken(issuerName, validKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
key auth.Key
|
||||
err error
|
||||
desc string
|
||||
key auth.Key
|
||||
managerReq jwt.Token
|
||||
managerResp []byte
|
||||
managerErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "issue new token",
|
||||
key: key(),
|
||||
err: nil,
|
||||
desc: "issue new token",
|
||||
key: validKey,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with OAuth token",
|
||||
@@ -69,7 +67,8 @@ func TestIssue(t *testing.T) {
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a domain",
|
||||
@@ -79,7 +78,8 @@ func TestIssue(t *testing.T) {
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a subject",
|
||||
@@ -89,7 +89,8 @@ func TestIssue(t *testing.T) {
|
||||
Subject: "",
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without type",
|
||||
@@ -99,7 +100,8 @@ func TestIssue(t *testing.T) {
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token without a domain and subject",
|
||||
@@ -110,116 +112,161 @@ func TestIssue(t *testing.T) {
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
managerResp: []byte(signedToken),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with failed to sign jwt",
|
||||
key: validKey,
|
||||
managerErr: svcerr.ErrAuthentication,
|
||||
err: authjwt.ErrSignJWT,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tc.managerReq = newToken(issuerName, tc.key)
|
||||
kmCall := keyManager.On("SignJWT", tc.managerReq).Return(tc.managerResp, tc.managerErr)
|
||||
tkn, err := tokenizer.Issue(tc.key)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
|
||||
}
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tokenizer := authjwt.New([]byte(secret))
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
token, err := tokenizer.Issue(key())
|
||||
validKey := key()
|
||||
signedTkn, parsedTkn, err := signToken(issuerName, validKey, true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
apiKey := key()
|
||||
apiKey.Type = auth.APIKey
|
||||
apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
|
||||
apiToken, err := tokenizer.Issue(apiKey)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
apiToken, _, err := signToken(issuerName, apiKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing api key expected to succeed: %s", err))
|
||||
|
||||
expKey := key()
|
||||
expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second)
|
||||
expToken, err := tokenizer.Issue(expKey)
|
||||
expToken, _, err := signToken(issuerName, expKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
|
||||
|
||||
emptySubjectKey := key()
|
||||
emptySubjectKey.Subject = ""
|
||||
emptySubjectToken, err := tokenizer.Issue(emptySubjectKey)
|
||||
signedEmptySubjectTkn, parsedEmptySubjectTkn, err := signToken(issuerName, emptySubjectKey, true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptyTypeKey := key()
|
||||
emptyTypeKey.Type = auth.KeyType(auth.InvitationKey + 1)
|
||||
emptyTypeToken, err := tokenizer.Issue(emptyTypeKey)
|
||||
emptyTypeToken, _, err := signToken(issuerName, emptyTypeKey, false)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptyKey := key()
|
||||
emptyKey.Subject = ""
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
inValidToken := newToken("invalid", key())
|
||||
signedInValidTkn, parsedInvalidTkn, err := signToken("invalid.issuer", key(), true)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
key auth.Key
|
||||
token string
|
||||
err error
|
||||
desc string
|
||||
key auth.Key
|
||||
token string
|
||||
managerRes jwt.Token
|
||||
managerErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "parse valid key",
|
||||
key: key(),
|
||||
token: token,
|
||||
err: nil,
|
||||
desc: "parse valid key",
|
||||
key: validKey,
|
||||
token: signedTkn,
|
||||
managerRes: parsedTkn,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse invalid key",
|
||||
key: auth.Key{},
|
||||
token: "invalid",
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "parse invalid key",
|
||||
key: auth.Key{},
|
||||
token: "invalid",
|
||||
managerErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse expired key",
|
||||
key: auth.Key{},
|
||||
token: expToken,
|
||||
err: auth.ErrExpiry,
|
||||
desc: "parse expired key",
|
||||
key: auth.Key{},
|
||||
token: expToken,
|
||||
managerErr: errJWTExpiryKey,
|
||||
err: auth.ErrExpiry,
|
||||
},
|
||||
{
|
||||
desc: "parse expired API key",
|
||||
key: apiKey,
|
||||
token: apiToken,
|
||||
err: auth.ErrExpiry,
|
||||
desc: "parse expired API key",
|
||||
key: apiKey,
|
||||
token: apiToken,
|
||||
managerErr: errJWTExpiryKey,
|
||||
err: auth.ErrExpiry,
|
||||
},
|
||||
{
|
||||
desc: "parse token with invalid issuer",
|
||||
key: auth.Key{},
|
||||
token: inValidToken,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "parse token with invalid issuer",
|
||||
key: auth.Key{},
|
||||
token: signedInValidTkn,
|
||||
managerRes: parsedInvalidTkn,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse token with invalid content",
|
||||
key: auth.Key{},
|
||||
token: newToken(issuerName, key()),
|
||||
err: authjwt.ErrJSONHandle,
|
||||
desc: "parse token with empty subject",
|
||||
key: emptySubjectKey,
|
||||
token: signedEmptySubjectTkn,
|
||||
managerRes: parsedEmptySubjectTkn,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty subject",
|
||||
key: emptySubjectKey,
|
||||
token: emptySubjectToken,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty type",
|
||||
key: emptyTypeKey,
|
||||
token: emptyTypeToken,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "parse token with empty type",
|
||||
key: emptyTypeKey,
|
||||
token: emptyTypeToken,
|
||||
managerRes: newToken(issuerName, emptyKey),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
key, err := tokenizer.Parse(tc.token)
|
||||
kmCall := keyManager.On("ParseJWT", tc.token).Return(tc.managerRes, tc.managerErr)
|
||||
key, err := tokenizer.Parse(context.Background(), tc.token)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
|
||||
}
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveJWKS(t *testing.T) {
|
||||
tokenizer := authjwt.New(keyManager)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
keys []auth.JWK
|
||||
retrieveErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve jwks with keys",
|
||||
keys: []auth.JWK{newJWK(t), newJWK(t)},
|
||||
},
|
||||
{
|
||||
desc: "retrieve empty jwks",
|
||||
keys: []auth.JWK{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
kmCall := keyManager.On("PublicJWKS", mock.Anything).Return(tc.keys, tc.retrieveErr)
|
||||
jwks := tokenizer.RetrieveJWKS()
|
||||
assert.Equal(t, tc.keys, jwks, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.keys, jwks))
|
||||
kmCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -236,3 +283,59 @@ func key() auth.Key {
|
||||
ExpiresAt: exp,
|
||||
}
|
||||
}
|
||||
|
||||
func newToken(issuerName string, key auth.Key) jwt.Token {
|
||||
builder := jwt.NewBuilder()
|
||||
builder.
|
||||
Issuer(issuerName).
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, key.Type).
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(roleField, key.Role)
|
||||
builder.Claim(VerifiedField, key.Verified)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
if key.ID != "" {
|
||||
builder.JwtID(key.ID)
|
||||
}
|
||||
tkn, _ := builder.Build()
|
||||
return tkn
|
||||
}
|
||||
|
||||
func signToken(issuerName string, key auth.Key, parseToken bool) (string, jwt.Token, error) {
|
||||
tkn := newToken(issuerName, key)
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
pubKey := &pKey.PublicKey
|
||||
sTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.RS256, pKey))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if !parseToken {
|
||||
return string(sTkn), nil, nil
|
||||
}
|
||||
pTkn, err := jwt.Parse(
|
||||
sTkn,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(jwa.RS256, pubKey),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return string(sTkn), pTkn, nil
|
||||
}
|
||||
|
||||
func newJWK(t *testing.T) auth.JWK {
|
||||
pKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.Nil(t, err, fmt.Sprintf("generating rsa key expected to succeed: %s", err))
|
||||
jwkKey, err := jwk.FromRaw(&pKey.PublicKey)
|
||||
require.Nil(t, err, fmt.Sprintf("creating jwk from rsa public key expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.KeyIDKey, "test-key-id")
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk key id expected to succeed: %s", err))
|
||||
err = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256.String())
|
||||
require.Nil(t, err, fmt.Sprintf("setting jwk algorithm expected to succeed: %s", err))
|
||||
return auth.NewJWK(jwkKey)
|
||||
}
|
||||
|
||||
+18
-23
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
@@ -34,27 +33,23 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
issuerName = "supermq.auth"
|
||||
tokenType = "type"
|
||||
userField = "user"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
oauthProviderField = "oauth_provider"
|
||||
oauthAccessTokenField = "access_token"
|
||||
oauthRefreshTokenField = "refresh_token"
|
||||
patPrefix = "pat"
|
||||
issuerName = "supermq.auth"
|
||||
tokenType = "type"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
patPrefix = "pat"
|
||||
)
|
||||
|
||||
type tokenizer struct {
|
||||
secret []byte
|
||||
keyManager auth.KeyManager
|
||||
}
|
||||
|
||||
var _ auth.Tokenizer = (*tokenizer)(nil)
|
||||
|
||||
// NewRepository instantiates an implementation of Token repository.
|
||||
func New(secret []byte) auth.Tokenizer {
|
||||
// New instantiates an implementation of Tokenizer service.
|
||||
func New(keyManager auth.KeyManager) auth.Tokenizer {
|
||||
return &tokenizer{
|
||||
secret: secret,
|
||||
keyManager: keyManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,14 +72,14 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
signedTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, tok.secret))
|
||||
signedTkn, err := tok.keyManager.SignJWT(tkn)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(ErrSignJWT, err)
|
||||
}
|
||||
return string(signedTkn), nil
|
||||
}
|
||||
|
||||
func (tok *tokenizer) Parse(token string) (auth.Key, error) {
|
||||
func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
|
||||
if len(token) >= 3 && token[:3] == patPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
@@ -94,7 +89,7 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
key, err := toKey(tkn)
|
||||
key, err := ToKey(tkn)
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
@@ -103,11 +98,7 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) {
|
||||
}
|
||||
|
||||
func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(jwa.HS512, tok.secret),
|
||||
)
|
||||
tkn, err := tok.keyManager.ParseJWT(token)
|
||||
if err != nil {
|
||||
if errors.Contains(err, errJWTExpiryKey) {
|
||||
return nil, auth.ErrExpiry
|
||||
@@ -128,7 +119,11 @@ func (tok *tokenizer) validateToken(token string) (jwt.Token, error) {
|
||||
return tkn, nil
|
||||
}
|
||||
|
||||
func toKey(tkn jwt.Token) (auth.Key, error) {
|
||||
func (tok *tokenizer) RetrieveJWKS() []auth.JWK {
|
||||
return tok.keyManager.PublicJWKS()
|
||||
}
|
||||
|
||||
func ToKey(tkn jwt.Token) (auth.Key, error) {
|
||||
data, err := json.Marshal(tkn.PrivateClaims())
|
||||
if err != nil {
|
||||
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
|
||||
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key.
|
||||
type JWK struct {
|
||||
key jwk.Key
|
||||
}
|
||||
|
||||
// NewJWK creates a new JWK from a jwk.Key.
|
||||
func NewJWK(key jwk.Key) JWK {
|
||||
return JWK{key: key}
|
||||
}
|
||||
|
||||
// Key returns the underlying jwk.Key.
|
||||
func (j JWK) Key() jwk.Key {
|
||||
return j.key
|
||||
}
|
||||
|
||||
// KeyManager represents a manager for JWT keys.
|
||||
type KeyManager interface {
|
||||
SignJWT(token jwt.Token) ([]byte, error)
|
||||
|
||||
ParseJWT(token string) (jwt.Token, error)
|
||||
|
||||
PublicJWKS() []JWK
|
||||
}
|
||||
|
||||
func IsSymmetricAlgorithm(alg string) (bool, error) {
|
||||
switch alg {
|
||||
case "HS256", "HS384", "HS512":
|
||||
return true, nil
|
||||
case "EdDSA":
|
||||
return false, nil
|
||||
default:
|
||||
return false, ErrUnsupportedKeyAlgorithm
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package asymmetric
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
var (
|
||||
errLoadingPrivateKey = errors.New("failed to load private key")
|
||||
errInvalidKeySize = errors.New("invalid ED25519 key size")
|
||||
errParsingPrivateKey = errors.New("failed to parse private key")
|
||||
errInvalidKeyType = errors.New("private key is not ED25519")
|
||||
errGeneratingKID = errors.New("failed to generate key ID")
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
privateKey jwk.Key
|
||||
publicKey jwk.Key
|
||||
kid string
|
||||
}
|
||||
|
||||
var _ auth.KeyManager = (*manager)(nil)
|
||||
|
||||
// NewKeyManager creates a new asymmetric key manager that loads the private key from a file.
|
||||
func NewKeyManager(privateKeyPath string, idProvider supermq.IDProvider) (auth.KeyManager, error) {
|
||||
kid, err := idProvider.ID()
|
||||
if err != nil {
|
||||
return nil, errors.Join(errGeneratingKID, err)
|
||||
}
|
||||
|
||||
privateJwk, publicJwk, err := loadKeyPair(privateKeyPath, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manager{
|
||||
privateKey: privateJwk,
|
||||
publicKey: publicJwk,
|
||||
kid: kid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, jwt.WithKey(jwa.EdDSA, km.privateKey))
|
||||
}
|
||||
|
||||
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
|
||||
set := jwk.NewSet()
|
||||
if err := set.AddKey(km.publicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(set, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tkn, nil
|
||||
}
|
||||
|
||||
func (km *manager) PublicJWKS() []auth.JWK {
|
||||
return []auth.JWK{auth.NewJWK(km.publicKey)}
|
||||
}
|
||||
|
||||
func loadKeyPair(privateKeyPath string, kid string) (jwk.Key, jwk.Key, error) {
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(errLoadingPrivateKey, err)
|
||||
}
|
||||
|
||||
var privateKey ed25519.PrivateKey
|
||||
block, _ := pem.Decode(privateKeyBytes)
|
||||
switch {
|
||||
case block != nil:
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(errParsingPrivateKey, err)
|
||||
}
|
||||
var ok bool
|
||||
privateKey, ok = parsedKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, errInvalidKeyType
|
||||
}
|
||||
default:
|
||||
if len(privateKeyBytes) != ed25519.PrivateKeySize {
|
||||
return nil, nil, errInvalidKeySize
|
||||
}
|
||||
privateKey = ed25519.PrivateKey(privateKeyBytes)
|
||||
}
|
||||
|
||||
publicKey := privateKey.Public().(ed25519.PublicKey)
|
||||
|
||||
privateJwk, err := jwk.FromRaw(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := privateJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
publicJwk, err := jwk.FromRaw(publicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.AlgorithmKey, jwa.EdDSA); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := publicJwk.Set(jwk.KeyIDKey, kid); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return privateJwk, publicJwk, nil
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package symmetric
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
type manager struct {
|
||||
algorithm jwa.KeyAlgorithm
|
||||
secret []byte
|
||||
}
|
||||
|
||||
var _ auth.KeyManager = (*manager)(nil)
|
||||
|
||||
func NewKeyManager(algorithm string, secret []byte) (auth.KeyManager, error) {
|
||||
alg := jwa.KeyAlgorithmFrom(algorithm)
|
||||
if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok {
|
||||
return nil, auth.ErrUnsupportedKeyAlgorithm
|
||||
}
|
||||
if len(secret) == 0 {
|
||||
return nil, auth.ErrInvalidSymmetricKey
|
||||
}
|
||||
return &manager{
|
||||
secret: secret,
|
||||
algorithm: alg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *manager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
return jwt.Sign(token, jwt.WithKey(km.algorithm, km.secret))
|
||||
}
|
||||
|
||||
func (km *manager) ParseJWT(token string) (jwt.Token, error) {
|
||||
return jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(km.algorithm, km.secret),
|
||||
)
|
||||
}
|
||||
|
||||
func (km *manager) PublicJWKS() []auth.JWK {
|
||||
return nil
|
||||
}
|
||||
@@ -100,6 +100,16 @@ func (lm *loggingMiddleware) Identify(ctx context.Context, token string) (id aut
|
||||
return lm.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.JWK) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
}
|
||||
lm.logger.Info("Retrieve JWKS completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
|
||||
@@ -67,6 +67,14 @@ func (ms *metricsMiddleware) Identify(ctx context.Context, token string) (auth.K
|
||||
return ms.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RetrieveJWKS() []auth.JWK {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "retrieve_jwks").Add(1)
|
||||
ms.latency.With("method", "retrieve_jwks").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "authorize").Add(1)
|
||||
|
||||
@@ -61,6 +61,10 @@ func (tm *tracingMiddleware) Identify(ctx context.Context, token string) (auth.K
|
||||
return tm.svc.Identify(ctx, token)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RetrieveJWKS() []auth.JWK {
|
||||
return tm.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "authorize", trace.WithAttributes(
|
||||
attribute.String("subject", pr.Subject),
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewKeyManager creates a new instance of KeyManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewKeyManager(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *KeyManager {
|
||||
mock := &KeyManager{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// KeyManager is an autogenerated mock type for the KeyManager type
|
||||
type KeyManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type KeyManager_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *KeyManager) EXPECT() *KeyManager_Expecter {
|
||||
return &KeyManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// ParseJWT provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) ParseJWT(token string) (jwt.Token, error) {
|
||||
ret := _mock.Called(token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ParseJWT")
|
||||
}
|
||||
|
||||
var r0 jwt.Token
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(string) (jwt.Token, error)); ok {
|
||||
return returnFunc(token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(string) jwt.Token); ok {
|
||||
r0 = returnFunc(token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(jwt.Token)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = returnFunc(token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// KeyManager_ParseJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseJWT'
|
||||
type KeyManager_ParseJWT_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ParseJWT is a helper method to define mock.On call
|
||||
// - token string
|
||||
func (_e *KeyManager_Expecter) ParseJWT(token interface{}) *KeyManager_ParseJWT_Call {
|
||||
return &KeyManager_ParseJWT_Call{Call: _e.mock.On("ParseJWT", token)}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) Run(run func(token string)) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 string
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) Return(token1 jwt.Token, err error) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Return(token1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_ParseJWT_Call) RunAndReturn(run func(token string) (jwt.Token, error)) *KeyManager_ParseJWT_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// PublicJWKS provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) PublicJWKS() []auth.JWK {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for PublicJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// KeyManager_PublicJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublicJWKS'
|
||||
type KeyManager_PublicJWKS_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// PublicJWKS is a helper method to define mock.On call
|
||||
func (_e *KeyManager_Expecter) PublicJWKS() *KeyManager_PublicJWKS_Call {
|
||||
return &KeyManager_PublicJWKS_Call{Call: _e.mock.On("PublicJWKS")}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) Run(run func()) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) Return(jWKs []auth.JWK) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_PublicJWKS_Call) RunAndReturn(run func() []auth.JWK) *KeyManager_PublicJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SignJWT provides a mock function for the type KeyManager
|
||||
func (_mock *KeyManager) SignJWT(token jwt.Token) ([]byte, error) {
|
||||
ret := _mock.Called(token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SignJWT")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(jwt.Token) ([]byte, error)); ok {
|
||||
return returnFunc(token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(jwt.Token) []byte); ok {
|
||||
r0 = returnFunc(token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(jwt.Token) error); ok {
|
||||
r1 = returnFunc(token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// KeyManager_SignJWT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignJWT'
|
||||
type KeyManager_SignJWT_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SignJWT is a helper method to define mock.On call
|
||||
// - token jwt.Token
|
||||
func (_e *KeyManager_Expecter) SignJWT(token interface{}) *KeyManager_SignJWT_Call {
|
||||
return &KeyManager_SignJWT_Call{Call: _e.mock.On("SignJWT", token)}
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) Run(run func(token jwt.Token)) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 jwt.Token
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(jwt.Token)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) Return(bytes []byte, err error) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Return(bytes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *KeyManager_SignJWT_Call) RunAndReturn(run func(token jwt.Token) ([]byte, error)) *KeyManager_SignJWT_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -1028,6 +1028,52 @@ func (_c *Service_ResetPATSecret_Call) RunAndReturn(run func(ctx context.Context
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveJWKS provides a mock function for the type Service
|
||||
func (_mock *Service) RetrieveJWKS() []auth.JWK {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
|
||||
type Service_RetrieveJWKS_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveJWKS is a helper method to define mock.On call
|
||||
func (_e *Service_Expecter) RetrieveJWKS() *Service_RetrieveJWKS_Call {
|
||||
return &Service_RetrieveJWKS_Call{Call: _e.mock.On("RetrieveJWKS")}
|
||||
}
|
||||
|
||||
func (_c *Service_RetrieveJWKS_Call) Run(run func()) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Service_RetrieveJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveKey provides a mock function for the type Service
|
||||
func (_mock *Service) RetrieveKey(ctx context.Context, token string, id string) (auth.Key, error) {
|
||||
ret := _mock.Called(ctx, token, id)
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewTokenizer creates a new instance of Tokenizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewTokenizer(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Tokenizer {
|
||||
mock := &Tokenizer{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Tokenizer is an autogenerated mock type for the Tokenizer type
|
||||
type Tokenizer struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Tokenizer_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Tokenizer) EXPECT() *Tokenizer_Expecter {
|
||||
return &Tokenizer_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Issue provides a mock function for the type Tokenizer
|
||||
func (_mock *Tokenizer) Issue(key auth.Key) (string, error) {
|
||||
ret := _mock.Called(key)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Issue")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(auth.Key) (string, error)); ok {
|
||||
return returnFunc(key)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(auth.Key) string); ok {
|
||||
r0 = returnFunc(key)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(auth.Key) error); ok {
|
||||
r1 = returnFunc(key)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Tokenizer_Issue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Issue'
|
||||
type Tokenizer_Issue_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Issue is a helper method to define mock.On call
|
||||
// - key auth.Key
|
||||
func (_e *Tokenizer_Expecter) Issue(key interface{}) *Tokenizer_Issue_Call {
|
||||
return &Tokenizer_Issue_Call{Call: _e.mock.On("Issue", key)}
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Issue_Call) Run(run func(key auth.Key)) *Tokenizer_Issue_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 auth.Key
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(auth.Key)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Issue_Call) Return(token string, err error) *Tokenizer_Issue_Call {
|
||||
_c.Call.Return(token, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Issue_Call) RunAndReturn(run func(key auth.Key) (string, error)) *Tokenizer_Issue_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Parse provides a mock function for the type Tokenizer
|
||||
func (_mock *Tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) {
|
||||
ret := _mock.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Parse")
|
||||
}
|
||||
|
||||
var r0 auth.Key
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (auth.Key, error)); ok {
|
||||
return returnFunc(ctx, token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) auth.Key); ok {
|
||||
r0 = returnFunc(ctx, token)
|
||||
} else {
|
||||
r0 = ret.Get(0).(auth.Key)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Tokenizer_Parse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Parse'
|
||||
type Tokenizer_Parse_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Parse is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - token string
|
||||
func (_e *Tokenizer_Expecter) Parse(ctx interface{}, token interface{}) *Tokenizer_Parse_Call {
|
||||
return &Tokenizer_Parse_Call{Call: _e.mock.On("Parse", ctx, token)}
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Parse_Call) Run(run func(ctx context.Context, token string)) *Tokenizer_Parse_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Parse_Call) Return(key auth.Key, err error) *Tokenizer_Parse_Call {
|
||||
_c.Call.Return(key, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_Parse_Call) RunAndReturn(run func(ctx context.Context, token string) (auth.Key, error)) *Tokenizer_Parse_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveJWKS provides a mock function for the type Tokenizer
|
||||
func (_mock *Tokenizer) RetrieveJWKS() []auth.JWK {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveJWKS")
|
||||
}
|
||||
|
||||
var r0 []auth.JWK
|
||||
if returnFunc, ok := ret.Get(0).(func() []auth.JWK); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.JWK)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Tokenizer_RetrieveJWKS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveJWKS'
|
||||
type Tokenizer_RetrieveJWKS_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveJWKS is a helper method to define mock.On call
|
||||
func (_e *Tokenizer_Expecter) RetrieveJWKS() *Tokenizer_RetrieveJWKS_Call {
|
||||
return &Tokenizer_RetrieveJWKS_Call{Call: _e.mock.On("RetrieveJWKS")}
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) Run(run func()) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) Return(jWKs []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Return(jWKs)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() []auth.JWK) *Tokenizer_RetrieveJWKS_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
+14
-7
@@ -83,6 +83,9 @@ type Authn interface {
|
||||
// is returned. If token is invalid, or invocation failed for some
|
||||
// other reason, non-nil error value is returned in response.
|
||||
Identify(ctx context.Context, token string) (Key, error)
|
||||
|
||||
// RetrieveJWKS retrieves a JWKs to validate issued tokens.
|
||||
RetrieveJWKS() []JWK
|
||||
}
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
@@ -145,7 +148,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
|
||||
}
|
||||
|
||||
func (svc service) Revoke(ctx context.Context, token, id string) error {
|
||||
issuerID, _, err := svc.authenticate(token)
|
||||
issuerID, _, err := svc.authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRevoke, err)
|
||||
}
|
||||
@@ -156,7 +159,7 @@ func (svc service) Revoke(ctx context.Context, token, id string) error {
|
||||
}
|
||||
|
||||
func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) {
|
||||
issuerID, _, err := svc.authenticate(token)
|
||||
issuerID, _, err := svc.authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return Key{}, errors.Wrap(errRetrieve, err)
|
||||
}
|
||||
@@ -169,7 +172,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro
|
||||
}
|
||||
|
||||
func (svc service) Identify(ctx context.Context, token string) (Key, error) {
|
||||
key, err := svc.tokenizer.Parse(token)
|
||||
key, err := svc.tokenizer.Parse(ctx, token)
|
||||
if errors.Contains(err, ErrExpiry) {
|
||||
err = svc.keys.Remove(ctx, key.Issuer, key.ID)
|
||||
return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err))
|
||||
@@ -198,6 +201,10 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svc service) RetrieveJWKS() []JWK {
|
||||
return svc.tokenizer.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
|
||||
if pr.PatID != "" && pr.TokenType == PersonalAccessTokenType {
|
||||
if err := svc.AuthorizePAT(ctx, pr.UserID, pr.PatID, EntityType(pr.EntityType), pr.OptionalDomainID, Operation(pr.Operation), pr.EntityID); err != nil {
|
||||
@@ -325,7 +332,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) {
|
||||
}
|
||||
|
||||
func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) {
|
||||
k, err := svc.tokenizer.Parse(token)
|
||||
k, err := svc.tokenizer.Parse(ctx, token)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(errRetrieve, err)
|
||||
}
|
||||
@@ -402,7 +409,7 @@ func (svc service) getUserRole(ctx context.Context, userID string) (role Role) {
|
||||
}
|
||||
|
||||
func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) {
|
||||
id, sub, err := svc.authenticate(token)
|
||||
id, sub, err := svc.authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(errIssueUser, err)
|
||||
}
|
||||
@@ -433,8 +440,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e
|
||||
return Token{AccessToken: tkn}, nil
|
||||
}
|
||||
|
||||
func (svc service) authenticate(token string) (string, string, error) {
|
||||
key, err := svc.tokenizer.Parse(token)
|
||||
func (svc service) authenticate(ctx context.Context, token string) (string, string, error) {
|
||||
key, err := svc.tokenizer.Parse(ctx, token)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
+426
-262
File diff suppressed because it is too large
Load Diff
+8
-1
@@ -3,11 +3,18 @@
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Tokenizer specifies API for encoding and decoding between string and Key.
|
||||
type Tokenizer interface {
|
||||
// Issue converts API Key to its string representation.
|
||||
Issue(key Key) (token string, err error)
|
||||
|
||||
// Parse extracts API Key data from string token.
|
||||
Parse(token string) (key Key, err error)
|
||||
Parse(ctx context.Context, token string) (key Key, err error)
|
||||
|
||||
// RetrieveJWKS returns the JSON Web Key Set.
|
||||
RetrieveJWKS() []JWK
|
||||
}
|
||||
|
||||
+54
-22
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/absmach/supermq/auth/cache"
|
||||
"github.com/absmach/supermq/auth/hasher"
|
||||
"github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/auth/keymanager/asymmetric"
|
||||
"github.com/absmach/supermq/auth/keymanager/symmetric"
|
||||
"github.com/absmach/supermq/auth/middleware"
|
||||
apostgres "github.com/absmach/supermq/auth/postgres"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
@@ -59,22 +61,26 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_AUTH_LOG_LEVEL" envDefault:"info"`
|
||||
SecretKey string `env:"SMQ_AUTH_SECRET_KEY" envDefault:"secret"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_AUTH_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
|
||||
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
|
||||
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"`
|
||||
CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"`
|
||||
LogLevel string `env:"SMQ_AUTH_LOG_LEVEL" envDefault:"info"`
|
||||
SecretKey string `env:"SMQ_AUTH_SECRET_KEY" envDefault:"secret"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_AUTH_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"`
|
||||
RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"`
|
||||
KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"`
|
||||
PrivateKeyPath string `env:"SMQ_AUTH_KEYS_PRIVATE_KEY_PATH" envDefault:"./ssl/keys/private.pem"`
|
||||
InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"`
|
||||
CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"`
|
||||
JWKSCacheMaxAge int `env:"SMQ_AUTH_JWKS_CACHE_MAX_AGE" envDefault:"900"`
|
||||
JWKSCacheStaleWhileRevalidate int `env:"SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE" envDefault:"60"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -144,7 +150,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.KeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to determine key algorithm type: %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
idProvider := uuid.New()
|
||||
|
||||
var keyManager auth.KeyManager
|
||||
switch {
|
||||
case isSymmetric:
|
||||
keyManager, err = symmetric.NewKeyManager(cfg.KeyAlgorithm, []byte(cfg.SecretKey))
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
default:
|
||||
keyManager, err = asymmetric.NewKeyManager(cfg.PrivateKeyPath, idProvider)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
svc, err := newService(db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration, keyManager, idProvider)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create service : %s\n", err.Error()))
|
||||
exitCode = 1
|
||||
@@ -180,7 +213,7 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, cfg.InstanceID), logger)
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, cfg.InstanceID, cfg.JWKSCacheMaxAge, cfg.JWKSCacheStaleWhileRevalidate), logger)
|
||||
|
||||
g.Go(func() error {
|
||||
return hs.Start()
|
||||
@@ -225,21 +258,20 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
|
||||
return nil
|
||||
}
|
||||
|
||||
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration) (auth.Service, error) {
|
||||
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, keyManager auth.KeyManager, idProvider supermq.IDProvider) (auth.Service, error) {
|
||||
cache := cache.NewPatsCache(cacheClient, keyDuration)
|
||||
|
||||
database := pgclient.NewDatabase(db, dbConfig, tracer)
|
||||
keysRepo := apostgres.New(database)
|
||||
patsRepo := apostgres.NewPatRepo(database, cache)
|
||||
hasher := hasher.New()
|
||||
idProvider := uuid.New()
|
||||
|
||||
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
|
||||
pService := spicedb.NewPolicyService(spicedbClient, logger)
|
||||
|
||||
t := jwt.New([]byte(cfg.SecretKey))
|
||||
tokenizer := jwt.New(keyManager)
|
||||
|
||||
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
|
||||
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
|
||||
svc = middleware.NewLogging(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics("auth", "api")
|
||||
svc = middleware.NewMetrics(svc, counter, latency)
|
||||
|
||||
+29
-4
@@ -18,6 +18,7 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/channels"
|
||||
grpcapi "github.com/absmach/supermq/channels/api/grpc"
|
||||
httpapi "github.com/absmach/supermq/channels/api/http"
|
||||
@@ -35,6 +36,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
"github.com/absmach/supermq/pkg/callout"
|
||||
@@ -98,6 +100,8 @@ type config struct {
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
|
||||
}
|
||||
|
||||
@@ -176,14 +180,35 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
authn, authnClient, err := authsvcAuthn.NewAuthentication(ctx, grpcCfg)
|
||||
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, grpcCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, grpcCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
|
||||
+30
-5
@@ -18,6 +18,7 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/clients"
|
||||
grpcapi "github.com/absmach/supermq/clients/api/grpc"
|
||||
httpapi "github.com/absmach/supermq/clients/api/http"
|
||||
@@ -34,6 +35,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
"github.com/absmach/supermq/pkg/callout"
|
||||
@@ -99,6 +101,8 @@ type config struct {
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
|
||||
}
|
||||
|
||||
@@ -186,14 +190,35 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
authn, authnClient, err := authsvcAuthn.NewAuthentication(ctx, grpcCfg)
|
||||
|
||||
alg, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !alg:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, grpcCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, grpcCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
@@ -224,7 +249,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
defer authzClient.Close()
|
||||
logger.Info("AuthZ successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
logger.Info("AuthZ successfully connected to auth gRPC server " + authzClient.Secure())
|
||||
|
||||
chgrpccfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&chgrpccfg, env.Options{Prefix: envPrefixChannels}); err != nil {
|
||||
|
||||
+28
-4
@@ -15,6 +15,7 @@ import (
|
||||
chclient "github.com/absmach/callhome/pkg/client"
|
||||
"github.com/absmach/supermq"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/domains"
|
||||
domainsSvc "github.com/absmach/supermq/domains"
|
||||
domainsgrpcapi "github.com/absmach/supermq/domains/api/grpc"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
"github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
"github.com/absmach/supermq/pkg/callout"
|
||||
@@ -83,6 +85,8 @@ type config struct {
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
|
||||
}
|
||||
|
||||
@@ -153,14 +157,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, clientConfig)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("authn failed to connect to auth gRPC server : %s", err.Error()))
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, clientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, clientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
database := postgres.NewDatabase(db, dbConfig, tracer)
|
||||
|
||||
+28
-4
@@ -17,6 +17,7 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/domains"
|
||||
dpostgres "github.com/absmach/supermq/domains/postgres"
|
||||
"github.com/absmach/supermq/groups"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
"github.com/absmach/supermq/pkg/callout"
|
||||
@@ -89,6 +91,8 @@ type config struct {
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"`
|
||||
}
|
||||
|
||||
@@ -157,14 +161,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientConfig)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error("failed to create authn " + err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
|
||||
+36
-12
@@ -22,11 +22,13 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
adapter "github.com/absmach/supermq/http"
|
||||
httpapi "github.com/absmach/supermq/http/api"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
@@ -60,13 +62,15 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_HTTP_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
LogLevel string `env:"SMQ_HTTP_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -163,14 +167,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
|
||||
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
|
||||
if err != nil {
|
||||
|
||||
+34
-10
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
chclient "github.com/absmach/callhome/pkg/client"
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/journal"
|
||||
httpapi "github.com/absmach/supermq/journal/api"
|
||||
"github.com/absmach/supermq/journal/events"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
@@ -51,12 +53,14 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_JOURNAL_LOG_LEVEL" envDefault:"info"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_JOURNAL_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
LogLevel string `env:"SMQ_JOURNAL_LOG_LEVEL" envDefault:"info"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_JOURNAL_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -105,14 +109,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientCfg)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
|
||||
+28
-6
@@ -19,10 +19,12 @@ import (
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
|
||||
grpcUsersV1 "github.com/absmach/supermq/api/grpc/users/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/internal/email"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
@@ -97,6 +99,8 @@ type config struct {
|
||||
PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"`
|
||||
VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"`
|
||||
VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
PassRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
@@ -194,17 +198,35 @@ func main() {
|
||||
defer tokenHandler.Close()
|
||||
logger.Info("Token service client successfully connected to auth gRPC server " + tokenHandler.Secure())
|
||||
|
||||
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientConfig)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error("failed to create authn " + err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
|
||||
var authn smqauthn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authClientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authClientConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
|
||||
|
||||
+36
-12
@@ -19,9 +19,11 @@ import (
|
||||
"github.com/absmach/supermq"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
@@ -56,13 +58,15 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
|
||||
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -158,14 +162,34 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg)
|
||||
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
var authn authn.Authentication
|
||||
var authnClient grpcclient.Handler
|
||||
switch {
|
||||
case !isSymmetric:
|
||||
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
|
||||
default:
|
||||
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
}
|
||||
|
||||
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
|
||||
if err != nil {
|
||||
|
||||
+5
-1
@@ -98,13 +98,17 @@ SMQ_AUTH_DB_SSL_MODE=disable
|
||||
SMQ_AUTH_DB_SSL_CERT=
|
||||
SMQ_AUTH_DB_SSL_KEY=
|
||||
SMQ_AUTH_DB_SSL_ROOT_CERT=
|
||||
SMQ_AUTH_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH
|
||||
SMQ_AUTH_ACCESS_TOKEN_DURATION="1h"
|
||||
SMQ_AUTH_REFRESH_TOKEN_DURATION="24h"
|
||||
SMQ_AUTH_KEYS_ALGORITHM="EdDSA"
|
||||
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH="./ssl/keys/private.key"
|
||||
SMQ_AUTH_INVITATION_DURATION="168h"
|
||||
SMQ_AUTH_ADAPTER_INSTANCE_ID=
|
||||
SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0
|
||||
SMQ_AUTH_CACHE_KEY_DURATION=10m
|
||||
SMQ_AUTH_JWKS_URL=http://${SMQ_AUTH_HTTP_HOST}:${SMQ_AUTH_HTTP_PORT}/keys/.well-known/jwks.json
|
||||
SMQ_AUTH_JWKS_CACHE_MAX_AGE=900
|
||||
SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE=60
|
||||
|
||||
#### Client Callout
|
||||
SMQ_CLIENTS_CALLOUT_URLS=""
|
||||
|
||||
@@ -23,6 +23,7 @@ volumes:
|
||||
supermq-domains-db-volume:
|
||||
supermq-domains-redis-volume:
|
||||
supermq-auth-redis-volume:
|
||||
supermq-auth-keys-volume:
|
||||
|
||||
services:
|
||||
spicedb:
|
||||
@@ -110,16 +111,17 @@ services:
|
||||
SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY}
|
||||
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
|
||||
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
|
||||
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
|
||||
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
|
||||
SMQ_AUTH_INVITATION_DURATION: ${SMQ_AUTH_INVITATION_DURATION}
|
||||
SMQ_AUTH_SECRET_KEY: ${SMQ_AUTH_SECRET_KEY}
|
||||
SMQ_AUTH_HTTP_HOST: ${SMQ_AUTH_HTTP_HOST}
|
||||
SMQ_AUTH_HTTP_PORT: ${SMQ_AUTH_HTTP_PORT}
|
||||
SMQ_AUTH_HTTP_SERVER_CERT: ${SMQ_AUTH_HTTP_SERVER_CERT}
|
||||
SMQ_AUTH_HTTP_SERVER_KEY: ${SMQ_AUTH_HTTP_SERVER_KEY}
|
||||
SMQ_AUTH_GRPC_HOST: ${SMQ_AUTH_GRPC_HOST}
|
||||
SMQ_AUTH_GRPC_PORT: ${SMQ_AUTH_GRPC_PORT}
|
||||
SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION}
|
||||
SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_AUTH_KEYS_PRIVATE_KEY_PATH: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:+/keys/private.key}
|
||||
## Compose supports parameter expansion in environment,
|
||||
## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty
|
||||
## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default
|
||||
@@ -150,6 +152,13 @@ services:
|
||||
volumes:
|
||||
- ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE}
|
||||
- supermq-pat-db-volume:/supermq-data
|
||||
- supermq-auth-keys-volume:/keys
|
||||
# Auth private key file
|
||||
- type: bind
|
||||
source: ${SMQ_AUTH_KEYS_PRIVATE_KEY_PATH:-ssl/certs/dummy/private_key}
|
||||
target: /keys/private.key
|
||||
bind:
|
||||
create_host_path: true
|
||||
# Auth gRPC mTLS server certificates
|
||||
- type: bind
|
||||
source: ${SMQ_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert}
|
||||
@@ -258,6 +267,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_GROUPS_GRPC_URL: ${SMQ_GROUPS_GRPC_URL}
|
||||
SMQ_GROUPS_GRPC_TIMEOUT: ${SMQ_GROUPS_GRPC_TIMEOUT}
|
||||
SMQ_GROUPS_GRPC_CLIENT_CERT: ${SMQ_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt}
|
||||
@@ -490,6 +500,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_CHANNELS_URL: ${SMQ_CHANNELS_URL}
|
||||
SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL}
|
||||
SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT}
|
||||
@@ -683,6 +694,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL}
|
||||
SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT}
|
||||
SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt}
|
||||
@@ -883,6 +895,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL}
|
||||
SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT}
|
||||
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
|
||||
@@ -1069,6 +1082,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY}
|
||||
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
|
||||
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
|
||||
@@ -1323,6 +1337,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL}
|
||||
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
|
||||
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
|
||||
@@ -1553,6 +1568,7 @@ services:
|
||||
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
|
||||
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
|
||||
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
|
||||
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
|
||||
SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL}
|
||||
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
|
||||
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEIPB+6hA+8rK067SdVlkWzgtxEUNysMhFFFzmGsKB1BAl
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -19,6 +19,8 @@ const (
|
||||
PersonalAccessToken
|
||||
)
|
||||
|
||||
const PatPrefix = "pat_"
|
||||
|
||||
func (t TokenType) String() string {
|
||||
switch t {
|
||||
case AccessToken:
|
||||
|
||||
@@ -15,8 +15,6 @@ import (
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
const patPrefix = "pat_"
|
||||
|
||||
type authentication struct {
|
||||
authSvcClient grpcAuthV1.AuthServiceClient
|
||||
}
|
||||
@@ -46,7 +44,7 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
|
||||
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(token, patPrefix) {
|
||||
if strings.HasPrefix(token, authn.PatPrefix) {
|
||||
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jwks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
|
||||
smqauth "github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/api/grpc/auth"
|
||||
smqjwt "github.com/absmach/supermq/auth/jwt"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
const (
|
||||
issuerName = "supermq.auth"
|
||||
cacheDuration = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
// errJWTExpiryKey is used to check if the token is expired.
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
// errFetchJWKS indicates an error fetching JWKS from URL.
|
||||
errFetchJWKS = errors.New("failed to fetch jwks")
|
||||
// errInvalidIssuer indicates an invalid issuer value.
|
||||
errInvalidIssuer = errors.New("invalid token issuer value")
|
||||
// ErrValidateJWTToken indicates a failure to validate JWT token.
|
||||
errValidateJWTToken = errors.New("failed to validate jwt token")
|
||||
|
||||
jwksCache = struct {
|
||||
sync.RWMutex
|
||||
jwks jwk.Set
|
||||
cachedAt time.Time
|
||||
}{}
|
||||
)
|
||||
|
||||
var _ authn.Authentication = (*authentication)(nil)
|
||||
|
||||
type authentication struct {
|
||||
jwksURL string
|
||||
authSvcClient grpcAuthV1.AuthServiceClient
|
||||
}
|
||||
|
||||
func NewAuthentication(ctx context.Context, jwksURL string, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
|
||||
client, err := grpcclient.NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "auth",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, grpcclient.ErrSvcNotServing
|
||||
}
|
||||
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
|
||||
|
||||
return authentication{
|
||||
jwksURL: jwksURL,
|
||||
authSvcClient: authSvcClient,
|
||||
}, client, nil
|
||||
}
|
||||
|
||||
func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
|
||||
if strings.HasPrefix(token, authn.PatPrefix) {
|
||||
res, err := a.authSvcClient.Authenticate(ctx, &grpcAuthV1.AuthNReq{Token: token})
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
|
||||
}
|
||||
|
||||
jwks, err := a.fetchJWKS(ctx)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
tkn, err := validateToken(token, jwks)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
key, err := smqjwt.ToKey(tkn)
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
return authn.Session{
|
||||
Type: authn.AccessToken,
|
||||
UserID: key.Subject,
|
||||
Role: authn.Role(key.Role),
|
||||
Verified: key.Verified,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a authentication) fetchJWKS(ctx context.Context) (jwk.Set, error) {
|
||||
jwksCache.RLock()
|
||||
if time.Since(jwksCache.cachedAt) < cacheDuration && jwksCache.jwks.Len() > 0 {
|
||||
cached := jwksCache.jwks
|
||||
jwksCache.RUnlock()
|
||||
return cached, nil
|
||||
}
|
||||
jwksCache.RUnlock()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", a.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errFetchJWKS
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set, err := jwk.Parse(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwksCache.Lock()
|
||||
jwksCache.jwks = set
|
||||
jwksCache.cachedAt = time.Now()
|
||||
jwksCache.Unlock()
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func validateToken(token string, jwks jwk.Set) (jwt.Token, error) {
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(token),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true)),
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Contains(err, errJWTExpiryKey) {
|
||||
return nil, smqauth.ErrExpiry
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
validator := jwt.ValidatorFunc(func(_ context.Context, t jwt.Token) jwt.ValidationError {
|
||||
if t.Issuer() != issuerName {
|
||||
return jwt.NewValidationError(errInvalidIssuer)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil {
|
||||
return nil, errors.Wrap(errValidateJWTToken, err)
|
||||
}
|
||||
|
||||
return tkn, nil
|
||||
}
|
||||
@@ -58,6 +58,8 @@ packages:
|
||||
Cache:
|
||||
Hasher:
|
||||
KeyRepository:
|
||||
KeyManager:
|
||||
Tokenizer:
|
||||
PATS:
|
||||
PATSRepository:
|
||||
Service:
|
||||
|
||||
Reference in New Issue
Block a user