diff --git a/alarms/api/endpoint.go b/alarms/api/endpoint.go index c53b246d8..10b881aa9 100644 --- a/alarms/api/endpoint.go +++ b/alarms/api/endpoint.go @@ -7,7 +7,6 @@ import ( "context" "github.com/absmach/magistrala/alarms" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" @@ -22,7 +21,7 @@ func updateAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return alarmRes{}, svcerr.ErrAuthorization } @@ -45,7 +44,7 @@ func viewAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return alarmRes{}, svcerr.ErrAuthorization } @@ -68,7 +67,7 @@ func listAlarmsEndpoint(svc alarms.Service) endpoint.Endpoint { return alarmsPageRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return alarmsPageRes{}, svcerr.ErrAuthorization } @@ -91,7 +90,7 @@ func deleteAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return alarmRes{}, svcerr.ErrAuthorization } diff --git a/alarms/api/transport.go b/alarms/api/transport.go index e3daa76ef..3087a18aa 100644 --- a/alarms/api/transport.go +++ b/alarms/api/transport.go @@ -24,7 +24,7 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) -func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.Authentication) http.Handler { +func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.AuthNMiddleware) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -33,7 +33,7 @@ func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider mux.Route("/{domainID}/alarms", func(r chi.Router) { r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) r.Use(api.RequestIDMiddleware(idp)) r.Get("/", otelhttp.NewHandler(kithttp.NewServer( diff --git a/bootstrap/api/endpoint.go b/bootstrap/api/endpoint.go index 12822c050..3f9074894 100644 --- a/bootstrap/api/endpoint.go +++ b/bootstrap/api/endpoint.go @@ -7,7 +7,6 @@ import ( "context" "github.com/absmach/magistrala/bootstrap" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" @@ -22,7 +21,7 @@ func addEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -65,7 +64,7 @@ func updateCertEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -93,7 +92,7 @@ func viewEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -134,7 +133,7 @@ func updateEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -165,7 +164,7 @@ func updateConnEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -190,7 +189,7 @@ func listEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -240,7 +239,7 @@ func removeEndpoint(svc bootstrap.Service) endpoint.Endpoint { return removeRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -276,7 +275,7 @@ func stateEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } diff --git a/bootstrap/api/endpoint_test.go b/bootstrap/api/endpoint_test.go index 4f536ccff..8d74cc796 100644 --- a/bootstrap/api/endpoint_test.go +++ b/bootstrap/api/endpoint_test.go @@ -180,7 +180,8 @@ func newBootstrapServer() (*httptest.Server, *mocks.Service, *authnmocks.Authent logger := smqlog.NewMock() svc := new(mocks.Service) authn := new(authnmocks.Authentication) - mux := bsapi.MakeHandler(svc, authn, bootstrap.NewConfigReader(encKey), logger, instanceID) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := bsapi.MakeHandler(svc, am, bootstrap.NewConfigReader(encKey), logger, instanceID) return httptest.NewServer(mux), svc, authn } diff --git a/bootstrap/api/transport.go b/bootstrap/api/transport.go index 0c480f7fb..547cc30e4 100644 --- a/bootstrap/api/transport.go +++ b/bootstrap/api/transport.go @@ -41,7 +41,7 @@ var ( ) // MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler { +func MakeHandler(svc bootstrap.Service, authn smqauthn.AuthNMiddleware, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, mgapi.EncodeError)), } @@ -50,8 +50,7 @@ func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bo r.Route("/{domainID}/clients", func(r chi.Router) { r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) - + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) r.Route("/configs", func(r chi.Router) { r.Post("/", otelhttp.NewHandler(kithttp.NewServer( addEndpoint(svc), @@ -97,7 +96,7 @@ func MakeHandler(svc bootstrap.Service, authn smqauthn.Authentication, reader bo }) }) - r.With(api.AuthenticateMiddleware(authn, true)).Put("/state/{clientID}", otelhttp.NewHandler(kithttp.NewServer( + r.With(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()).Put("/state/{clientID}", otelhttp.NewHandler(kithttp.NewServer( stateEndpoint(svc), decodeStateRequest, api.EncodeResponse, diff --git a/cmd/alarms/main.go b/cmd/alarms/main.go index 736faff76..a213689b0 100644 --- a/cmd/alarms/main.go +++ b/cmd/alarms/main.go @@ -18,6 +18,7 @@ import ( alarmsRepo "github.com/absmach/magistrala/alarms/postgres" "github.com/absmach/magistrala/pkg/prometheus" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/authn/authsvc" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" @@ -108,6 +109,7 @@ func main() { exitCode = 1 return } + am := smqauthn.NewAuthNMiddleware(authn) defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) @@ -152,7 +154,7 @@ func main() { exitCode = 1 return } - hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, authn), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, am), logger) pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) if err != nil { diff --git a/cmd/bootstrap/main.go b/cmd/bootstrap/main.go index 1da07bb8d..e36c15e7c 100644 --- a/cmd/bootstrap/main.go +++ b/cmd/bootstrap/main.go @@ -22,6 +22,7 @@ import ( "github.com/absmach/magistrala/bootstrap/tracing" "github.com/absmach/supermq" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -148,6 +149,7 @@ func main() { exitCode = 1 return } + am := smqauthn.NewAuthNMiddleware(authn) logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) defer authnClient.Close() @@ -196,7 +198,7 @@ func main() { exitCode = 1 return } - hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, bootstrap.NewConfigReader([]byte(cfg.EncKey)), logger, cfg.InstanceID), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, bootstrap.NewConfigReader([]byte(cfg.EncKey)), logger, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/cmd/re/main.go b/cmd/re/main.go index f9a586e84..f6945bed6 100644 --- a/cmd/re/main.go +++ b/cmd/re/main.go @@ -29,6 +29,7 @@ import ( grpcClient "github.com/absmach/magistrala/readers/api/grpc" "github.com/absmach/supermq" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authnsvc "github.com/absmach/supermq/pkg/authn/authsvc" mgauthz "github.com/absmach/supermq/pkg/authz" authzsvc "github.com/absmach/supermq/pkg/authz/authsvc" @@ -207,6 +208,8 @@ func main() { return } + am := smqauthn.NewAuthNMiddleware(authn) + defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) runInfo := make(chan pkglog.RunInfo, channBuffer) @@ -280,7 +283,7 @@ func main() { mux := chi.NewRouter() - httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID), logger) + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/cmd/reports/main.go b/cmd/reports/main.go index 0da35a201..049d46d84 100644 --- a/cmd/reports/main.go +++ b/cmd/reports/main.go @@ -27,6 +27,7 @@ import ( repg "github.com/absmach/magistrala/reports/postgres" "github.com/absmach/supermq" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authnsvc "github.com/absmach/supermq/pkg/authn/authsvc" mgauthz "github.com/absmach/supermq/pkg/authz" authzsvc "github.com/absmach/supermq/pkg/authz/authsvc" @@ -183,6 +184,7 @@ func main() { return } + am := smqauthn.NewAuthNMiddleware(authn) defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) @@ -245,7 +247,7 @@ func main() { mux := chi.NewRouter() - httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID), logger) + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/consumers/notifiers/mocks/service.go b/consumers/notifiers/mocks/service.go index 00e3f698a..fc4c1aef5 100644 --- a/consumers/notifiers/mocks/service.go +++ b/consumers/notifiers/mocks/service.go @@ -42,7 +42,7 @@ func (_m *Service) EXPECT() *Service_Expecter { } // ConsumeBlocking provides a mock function for the type Service -func (_mock *Service) ConsumeBlocking(ctx context.Context, messages interface{}) error { +func (_mock *Service) ConsumeBlocking(ctx context.Context, messages any) error { ret := _mock.Called(ctx, messages) if len(ret) == 0 { @@ -50,7 +50,7 @@ func (_mock *Service) ConsumeBlocking(ctx context.Context, messages interface{}) } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, interface{}) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, any) error); ok { r0 = returnFunc(ctx, messages) } else { r0 = ret.Error(0) @@ -65,20 +65,20 @@ type Service_ConsumeBlocking_Call struct { // ConsumeBlocking is a helper method to define mock.On call // - ctx context.Context -// - messages interface{} +// - messages any func (_e *Service_Expecter) ConsumeBlocking(ctx interface{}, messages interface{}) *Service_ConsumeBlocking_Call { return &Service_ConsumeBlocking_Call{Call: _e.mock.On("ConsumeBlocking", ctx, messages)} } -func (_c *Service_ConsumeBlocking_Call) Run(run func(ctx context.Context, messages interface{})) *Service_ConsumeBlocking_Call { +func (_c *Service_ConsumeBlocking_Call) Run(run func(ctx context.Context, messages any)) *Service_ConsumeBlocking_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 interface{} + var arg1 any if args[1] != nil { - arg1 = args[1].(interface{}) + arg1 = args[1].(any) } run( arg0, @@ -93,7 +93,7 @@ func (_c *Service_ConsumeBlocking_Call) Return(err error) *Service_ConsumeBlocki return _c } -func (_c *Service_ConsumeBlocking_Call) RunAndReturn(run func(ctx context.Context, messages interface{}) error) *Service_ConsumeBlocking_Call { +func (_c *Service_ConsumeBlocking_Call) RunAndReturn(run func(ctx context.Context, messages any) error) *Service_ConsumeBlocking_Call { _c.Call.Return(run) return _c } diff --git a/pkg/sdk/bootstrap_test.go b/pkg/sdk/bootstrap_test.go index a70074535..db10af191 100644 --- a/pkg/sdk/bootstrap_test.go +++ b/pkg/sdk/bootstrap_test.go @@ -138,7 +138,9 @@ func setupBootstrap() (*httptest.Server, *bmocks.Service, *bmocks.ConfigReader, reader := new(bmocks.ConfigReader) logger := smqlog.NewMock() authn := new(authnmocks.Authentication) - mux := api.MakeHandler(bsvc, authn, reader, logger, "") + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + mux := api.MakeHandler(bsvc, am, reader, logger, "") return httptest.NewServer(mux), bsvc, reader, authn } diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index ed1459693..8fafb7b70 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -8731,6 +8731,65 @@ func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domai return _c } +// SendVerification provides a mock function for the type SDK +func (_mock *SDK) SendVerification(ctx context.Context, token string) errors.SDKError { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for SendVerification") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok { + r0 = returnFunc(ctx, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification' +type SDK_SendVerification_Call struct { + *mock.Call +} + +// SendVerification is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *SDK_Expecter) SendVerification(ctx interface{}, token interface{}) *SDK_SendVerification_Call { + return &SDK_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, token)} +} + +func (_c *SDK_SendVerification_Call) Run(run func(ctx context.Context, token string)) *SDK_SendVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SDK_SendVerification_Call) Return(sDKError errors.SDKError) *SDK_SendVerification_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_SendVerification_Call) RunAndReturn(run func(ctx context.Context, token string) errors.SDKError) *SDK_SendVerification_Call { + _c.Call.Return(run) + return _c +} + // SetChannelParent provides a mock function for the type SDK func (_mock *SDK) SetChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError { ret := _mock.Called(ctx, id, domainID, groupID, token) @@ -10911,6 +10970,65 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk0.Pag return _c } +// VerifyEmail provides a mock function for the type SDK +func (_mock *SDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError { + ret := _mock.Called(ctx, verificationToken) + + if len(ret) == 0 { + panic("no return value specified for VerifyEmail") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok { + r0 = returnFunc(ctx, verificationToken) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail' +type SDK_VerifyEmail_Call struct { + *mock.Call +} + +// VerifyEmail is a helper method to define mock.On call +// - ctx context.Context +// - verificationToken string +func (_e *SDK_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *SDK_VerifyEmail_Call { + return &SDK_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)} +} + +func (_c *SDK_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *SDK_VerifyEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SDK_VerifyEmail_Call) Return(sDKError errors.SDKError) *SDK_VerifyEmail_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) errors.SDKError) *SDK_VerifyEmail_Call { + _c.Call.Return(run) + return _c +} + // ViewBootstrap provides a mock function for the type SDK func (_mock *SDK) ViewBootstrap(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { ret := _mock.Called(ctx, id, domainID, token) diff --git a/re/api/endpoints.go b/re/api/endpoints.go index 0088dd8d5..fa8c3b356 100644 --- a/re/api/endpoints.go +++ b/re/api/endpoints.go @@ -7,7 +7,6 @@ import ( "context" "github.com/absmach/magistrala/re" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" @@ -17,7 +16,7 @@ import ( func addRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -36,7 +35,7 @@ func addRuleEndpoint(s re.Service) endpoint.Endpoint { func viewRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -55,7 +54,7 @@ func viewRuleEndpoint(s re.Service) endpoint.Endpoint { func updateRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -79,7 +78,7 @@ func updateRuleTagsEndpoint(svc re.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -99,7 +98,7 @@ func updateRuleTagsEndpoint(svc re.Service) endpoint.Endpoint { func updateRuleScheduleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -124,7 +123,7 @@ func updateRuleScheduleEndpoint(s re.Service) endpoint.Endpoint { func listRulesEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -146,7 +145,7 @@ func listRulesEndpoint(s re.Service) endpoint.Endpoint { func deleteRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -165,7 +164,7 @@ func deleteRuleEndpoint(s re.Service) endpoint.Endpoint { func enableRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -186,7 +185,7 @@ func enableRuleEndpoint(s re.Service) endpoint.Endpoint { func disableRuleEndpoint(s re.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } diff --git a/re/api/endpoints_test.go b/re/api/endpoints_test.go index 5b802d86f..39f127f85 100644 --- a/re/api/endpoints_test.go +++ b/re/api/endpoints_test.go @@ -100,7 +100,9 @@ func newRuleEngineServer() (*httptest.Server, *mocks.Service, *authnmocks.Authen logger := smqlog.NewMock() mux := chi.NewRouter() - api.MakeHandler(svc, authn, mux, logger, "") + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + api.MakeHandler(svc, am, mux, logger, "") return httptest.NewServer(mux), svc, authn } diff --git a/re/api/transport.go b/re/api/transport.go index 25ba9cc32..54f0a41f9 100644 --- a/re/api/transport.go +++ b/re/api/transport.go @@ -14,7 +14,7 @@ import ( "github.com/absmach/supermq" api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" - mgauthn "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" @@ -28,12 +28,12 @@ const ( ) // MakeHandler creates an HTTP handler for the service endpoints. -func MakeHandler(svc re.Service, authn mgauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { +func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } mux.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) r.Route("/{domainID}", func(r chi.Router) { r.Route("/rules", func(r chi.Router) { r.Post("/", otelhttp.NewHandler(kithttp.NewServer( diff --git a/reports/api/endpoints.go b/reports/api/endpoints.go index 954816acd..b1c0b08c8 100644 --- a/reports/api/endpoints.go +++ b/reports/api/endpoints.go @@ -7,7 +7,6 @@ import ( "context" "github.com/absmach/magistrala/reports" - api "github.com/absmach/supermq/api/http" "github.com/absmach/supermq/pkg/authn" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/go-kit/kit/endpoint" @@ -15,7 +14,7 @@ import ( func generateReportEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -51,7 +50,7 @@ func generateReportEndpoint(svc reports.Service) endpoint.Endpoint { func listReportsConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -79,7 +78,7 @@ func listReportsConfigEndpoint(svc reports.Service) endpoint.Endpoint { func deleteReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -100,7 +99,7 @@ func deleteReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func updateReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -121,7 +120,7 @@ func updateReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func updateReportScheduleEndpoint(s reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -146,7 +145,7 @@ func updateReportScheduleEndpoint(s reports.Service) endpoint.Endpoint { func viewReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -167,7 +166,7 @@ func viewReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func addReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -191,7 +190,7 @@ func addReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func enableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -212,7 +211,7 @@ func enableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func disableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -233,7 +232,7 @@ func disableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { func updateReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -254,7 +253,7 @@ func updateReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { func viewReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -275,7 +274,7 @@ func viewReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { func deleteReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } diff --git a/reports/api/endpoints_test.go b/reports/api/endpoints_test.go index e4d52b3bb..3ab77e3cd 100644 --- a/reports/api/endpoints_test.go +++ b/reports/api/endpoints_test.go @@ -109,7 +109,9 @@ func newReportsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic logger := smqlog.NewMock() mux := chi.NewRouter() - api.MakeHandler(svc, authn, mux, logger, "") + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + api.MakeHandler(svc, am, mux, logger, "") return httptest.NewServer(mux), svc, authn } diff --git a/reports/api/transport.go b/reports/api/transport.go index 54f71ad50..236455b58 100644 --- a/reports/api/transport.go +++ b/reports/api/transport.go @@ -15,7 +15,7 @@ import ( "github.com/absmach/supermq" api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" - mgauthn "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" @@ -30,12 +30,12 @@ const ( ) // MakeHandler creates an HTTP handler for the service endpoints. -func MakeHandler(svc reports.Service, authn mgauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { +func MakeHandler(svc reports.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } mux.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) r.Route("/{domainID}", func(r chi.Router) { r.Route("/reports", func(r chi.Router) { r.Post("/", otelhttp.NewHandler(kithttp.NewServer(