NOISSUE - Fix bugs caused by SMQ update (#301)

* Fix SMQ-caused issues

Signed-off-by: dusan <borovcanindusan1@gmail.com>

* Fix tests

Signed-off-by: dusan <borovcanindusan1@gmail.com>

---------

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2025-09-06 21:58:41 +02:00
committed by GitHub
parent 2b97993c30
commit be7ee7a877
18 changed files with 193 additions and 64 deletions
+4 -5
View File
@@ -7,7 +7,6 @@ import (
"context"
"github.com/absmach/magistrala/alarms"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
@@ -22,7 +21,7 @@ func updateAlarmEndpoint(svc alarms.Service) endpoint.Endpoint {
return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return alarmRes{}, svcerr.ErrAuthorization
}
@@ -45,7 +44,7 @@ func viewAlarmEndpoint(svc alarms.Service) endpoint.Endpoint {
return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return alarmRes{}, svcerr.ErrAuthorization
}
@@ -68,7 +67,7 @@ func listAlarmsEndpoint(svc alarms.Service) endpoint.Endpoint {
return alarmsPageRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return alarmsPageRes{}, svcerr.ErrAuthorization
}
@@ -91,7 +90,7 @@ func deleteAlarmEndpoint(svc alarms.Service) endpoint.Endpoint {
return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return alarmRes{}, svcerr.ErrAuthorization
}
+2 -2
View File
@@ -24,7 +24,7 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.Authentication) http.Handler {
func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.AuthNMiddleware) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -33,7 +33,7 @@ func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider
mux.Route("/{domainID}/alarms", func(r chi.Router) {
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
+8 -9
View File
@@ -7,7 +7,6 @@ import (
"context"
"github.com/absmach/magistrala/bootstrap"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
@@ -22,7 +21,7 @@ func addEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -65,7 +64,7 @@ func updateCertEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -93,7 +92,7 @@ func viewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -134,7 +133,7 @@ func updateEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -165,7 +164,7 @@ func updateConnEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -190,7 +189,7 @@ func listEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -240,7 +239,7 @@ func removeEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return removeRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -276,7 +275,7 @@ func stateEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
+2 -1
View File
@@ -180,7 +180,8 @@ func newBootstrapServer() (*httptest.Server, *mocks.Service, *authnmocks.Authent
logger := smqlog.NewMock()
svc := new(mocks.Service)
authn := new(authnmocks.Authentication)
mux := bsapi.MakeHandler(svc, authn, bootstrap.NewConfigReader(encKey), logger, instanceID)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := bsapi.MakeHandler(svc, am, bootstrap.NewConfigReader(encKey), logger, instanceID)
return httptest.NewServer(mux), svc, authn
}
+3 -4
View File
@@ -41,7 +41,7 @@ var (
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler {
func MakeHandler(svc bootstrap.Service, authn smqauthn.AuthNMiddleware, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, mgapi.EncodeError)),
}
@@ -50,8 +50,7 @@ func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bo
r.Route("/{domainID}/clients", func(r chi.Router) {
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/configs", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
addEndpoint(svc),
@@ -97,7 +96,7 @@ func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bo
})
})
r.With(api.AuthenticateMiddleware(authn, true)).Put("/state/{clientID}", otelhttp.NewHandler(kithttp.NewServer(
r.With(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()).Put("/state/{clientID}", otelhttp.NewHandler(kithttp.NewServer(
stateEndpoint(svc),
decodeStateRequest,
api.EncodeResponse,
+3 -1
View File
@@ -18,6 +18,7 @@ import (
alarmsRepo "github.com/absmach/magistrala/alarms/postgres"
"github.com/absmach/magistrala/pkg/prometheus"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authn/authsvc"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
@@ -108,6 +109,7 @@ func main() {
exitCode = 1
return
}
am := smqauthn.NewAuthNMiddleware(authn)
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
@@ -152,7 +154,7 @@ func main() {
exitCode = 1
return
}
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, authn), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, am), logger)
pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger)
if err != nil {
+3 -1
View File
@@ -22,6 +22,7 @@ import (
"github.com/absmach/magistrala/bootstrap/tracing"
"github.com/absmach/supermq"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -148,6 +149,7 @@ func main() {
exitCode = 1
return
}
am := smqauthn.NewAuthNMiddleware(authn)
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
defer authnClient.Close()
@@ -196,7 +198,7 @@ func main() {
exitCode = 1
return
}
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, bootstrap.NewConfigReader([]byte(cfg.EncKey)), logger, cfg.InstanceID), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, bootstrap.NewConfigReader([]byte(cfg.EncKey)), logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+4 -1
View File
@@ -29,6 +29,7 @@ import (
grpcClient "github.com/absmach/magistrala/readers/api/grpc"
"github.com/absmach/supermq"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnsvc "github.com/absmach/supermq/pkg/authn/authsvc"
mgauthz "github.com/absmach/supermq/pkg/authz"
authzsvc "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -207,6 +208,8 @@ func main() {
return
}
am := smqauthn.NewAuthNMiddleware(authn)
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
runInfo := make(chan pkglog.RunInfo, channBuffer)
@@ -280,7 +283,7 @@ func main() {
mux := chi.NewRouter()
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID), logger)
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+3 -1
View File
@@ -27,6 +27,7 @@ import (
repg "github.com/absmach/magistrala/reports/postgres"
"github.com/absmach/supermq"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnsvc "github.com/absmach/supermq/pkg/authn/authsvc"
mgauthz "github.com/absmach/supermq/pkg/authz"
authzsvc "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -183,6 +184,7 @@ func main() {
return
}
am := smqauthn.NewAuthNMiddleware(authn)
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
@@ -245,7 +247,7 @@ func main() {
mux := chi.NewRouter()
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID), logger)
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+7 -7
View File
@@ -42,7 +42,7 @@ func (_m *Service) EXPECT() *Service_Expecter {
}
// ConsumeBlocking provides a mock function for the type Service
func (_mock *Service) ConsumeBlocking(ctx context.Context, messages interface{}) error {
func (_mock *Service) ConsumeBlocking(ctx context.Context, messages any) error {
ret := _mock.Called(ctx, messages)
if len(ret) == 0 {
@@ -50,7 +50,7 @@ func (_mock *Service) ConsumeBlocking(ctx context.Context, messages interface{})
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, interface{}) error); ok {
if returnFunc, ok := ret.Get(0).(func(context.Context, any) error); ok {
r0 = returnFunc(ctx, messages)
} else {
r0 = ret.Error(0)
@@ -65,20 +65,20 @@ type Service_ConsumeBlocking_Call struct {
// ConsumeBlocking is a helper method to define mock.On call
// - ctx context.Context
// - messages interface{}
// - messages any
func (_e *Service_Expecter) ConsumeBlocking(ctx interface{}, messages interface{}) *Service_ConsumeBlocking_Call {
return &Service_ConsumeBlocking_Call{Call: _e.mock.On("ConsumeBlocking", ctx, messages)}
}
func (_c *Service_ConsumeBlocking_Call) Run(run func(ctx context.Context, messages interface{})) *Service_ConsumeBlocking_Call {
func (_c *Service_ConsumeBlocking_Call) Run(run func(ctx context.Context, messages any)) *Service_ConsumeBlocking_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 interface{}
var arg1 any
if args[1] != nil {
arg1 = args[1].(interface{})
arg1 = args[1].(any)
}
run(
arg0,
@@ -93,7 +93,7 @@ func (_c *Service_ConsumeBlocking_Call) Return(err error) *Service_ConsumeBlocki
return _c
}
func (_c *Service_ConsumeBlocking_Call) RunAndReturn(run func(ctx context.Context, messages interface{}) error) *Service_ConsumeBlocking_Call {
func (_c *Service_ConsumeBlocking_Call) RunAndReturn(run func(ctx context.Context, messages any) error) *Service_ConsumeBlocking_Call {
_c.Call.Return(run)
return _c
}
+3 -1
View File
@@ -138,7 +138,9 @@ func setupBootstrap() (*httptest.Server, *bmocks.Service, *bmocks.ConfigReader,
reader := new(bmocks.ConfigReader)
logger := smqlog.NewMock()
authn := new(authnmocks.Authentication)
mux := api.MakeHandler(bsvc, authn, reader, logger, "")
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := api.MakeHandler(bsvc, am, reader, logger, "")
return httptest.NewServer(mux), bsvc, reader, authn
}
+118
View File
@@ -8731,6 +8731,65 @@ func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domai
return _c
}
// SendVerification provides a mock function for the type SDK
func (_mock *SDK) SendVerification(ctx context.Context, token string) errors.SDKError {
ret := _mock.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for SendVerification")
}
var r0 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
r0 = returnFunc(ctx, token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(errors.SDKError)
}
}
return r0
}
// SDK_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
type SDK_SendVerification_Call struct {
*mock.Call
}
// SendVerification is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *SDK_Expecter) SendVerification(ctx interface{}, token interface{}) *SDK_SendVerification_Call {
return &SDK_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, token)}
}
func (_c *SDK_SendVerification_Call) Run(run func(ctx context.Context, token string)) *SDK_SendVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *SDK_SendVerification_Call) Return(sDKError errors.SDKError) *SDK_SendVerification_Call {
_c.Call.Return(sDKError)
return _c
}
func (_c *SDK_SendVerification_Call) RunAndReturn(run func(ctx context.Context, token string) errors.SDKError) *SDK_SendVerification_Call {
_c.Call.Return(run)
return _c
}
// SetChannelParent provides a mock function for the type SDK
func (_mock *SDK) SetChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError {
ret := _mock.Called(ctx, id, domainID, groupID, token)
@@ -10911,6 +10970,65 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk0.Pag
return _c
}
// VerifyEmail provides a mock function for the type SDK
func (_mock *SDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError {
ret := _mock.Called(ctx, verificationToken)
if len(ret) == 0 {
panic("no return value specified for VerifyEmail")
}
var r0 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
r0 = returnFunc(ctx, verificationToken)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(errors.SDKError)
}
}
return r0
}
// SDK_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail'
type SDK_VerifyEmail_Call struct {
*mock.Call
}
// VerifyEmail is a helper method to define mock.On call
// - ctx context.Context
// - verificationToken string
func (_e *SDK_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *SDK_VerifyEmail_Call {
return &SDK_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)}
}
func (_c *SDK_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *SDK_VerifyEmail_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *SDK_VerifyEmail_Call) Return(sDKError errors.SDKError) *SDK_VerifyEmail_Call {
_c.Call.Return(sDKError)
return _c
}
func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) errors.SDKError) *SDK_VerifyEmail_Call {
_c.Call.Return(run)
return _c
}
// ViewBootstrap provides a mock function for the type SDK
func (_mock *SDK) ViewBootstrap(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) {
ret := _mock.Called(ctx, id, domainID, token)
+9 -10
View File
@@ -7,7 +7,6 @@ import (
"context"
"github.com/absmach/magistrala/re"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
@@ -17,7 +16,7 @@ import (
func addRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -36,7 +35,7 @@ func addRuleEndpoint(s re.Service) endpoint.Endpoint {
func viewRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -55,7 +54,7 @@ func viewRuleEndpoint(s re.Service) endpoint.Endpoint {
func updateRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -79,7 +78,7 @@ func updateRuleTagsEndpoint(svc re.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -99,7 +98,7 @@ func updateRuleTagsEndpoint(svc re.Service) endpoint.Endpoint {
func updateRuleScheduleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -124,7 +123,7 @@ func updateRuleScheduleEndpoint(s re.Service) endpoint.Endpoint {
func listRulesEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -146,7 +145,7 @@ func listRulesEndpoint(s re.Service) endpoint.Endpoint {
func deleteRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -165,7 +164,7 @@ func deleteRuleEndpoint(s re.Service) endpoint.Endpoint {
func enableRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -186,7 +185,7 @@ func enableRuleEndpoint(s re.Service) endpoint.Endpoint {
func disableRuleEndpoint(s re.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
+3 -1
View File
@@ -100,7 +100,9 @@ func newRuleEngineServer() (*httptest.Server, *mocks.Service, *authnmocks.Authen
logger := smqlog.NewMock()
mux := chi.NewRouter()
api.MakeHandler(svc, authn, mux, logger, "")
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
api.MakeHandler(svc, am, mux, logger, "")
return httptest.NewServer(mux), svc, authn
}
+3 -3
View File
@@ -14,7 +14,7 @@ import (
"github.com/absmach/supermq"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
mgauthn "github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
@@ -28,12 +28,12 @@ const (
)
// MakeHandler creates an HTTP handler for the service endpoints.
func MakeHandler(svc re.Service, authn mgauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
mux.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/{domainID}", func(r chi.Router) {
r.Route("/rules", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+12 -13
View File
@@ -7,7 +7,6 @@ import (
"context"
"github.com/absmach/magistrala/reports"
api "github.com/absmach/supermq/api/http"
"github.com/absmach/supermq/pkg/authn"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/go-kit/kit/endpoint"
@@ -15,7 +14,7 @@ import (
func generateReportEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -51,7 +50,7 @@ func generateReportEndpoint(svc reports.Service) endpoint.Endpoint {
func listReportsConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -79,7 +78,7 @@ func listReportsConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func deleteReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -100,7 +99,7 @@ func deleteReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func updateReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -121,7 +120,7 @@ func updateReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func updateReportScheduleEndpoint(s reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -146,7 +145,7 @@ func updateReportScheduleEndpoint(s reports.Service) endpoint.Endpoint {
func viewReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -167,7 +166,7 @@ func viewReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func addReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -191,7 +190,7 @@ func addReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func enableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -212,7 +211,7 @@ func enableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func disableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -233,7 +232,7 @@ func disableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint {
func updateReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -254,7 +253,7 @@ func updateReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint {
func viewReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -275,7 +274,7 @@ func viewReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint {
func deleteReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
+3 -1
View File
@@ -109,7 +109,9 @@ func newReportsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic
logger := smqlog.NewMock()
mux := chi.NewRouter()
api.MakeHandler(svc, authn, mux, logger, "")
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
api.MakeHandler(svc, am, mux, logger, "")
return httptest.NewServer(mux), svc, authn
}
+3 -3
View File
@@ -15,7 +15,7 @@ import (
"github.com/absmach/supermq"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
mgauthn "github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
@@ -30,12 +30,12 @@ const (
)
// MakeHandler creates an HTTP handler for the service endpoints.
func MakeHandler(svc reports.Service, authn mgauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
func MakeHandler(svc reports.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
mux.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/{domainID}", func(r chi.Router) {
r.Route("/reports", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(