mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
NOISSUE - Update mGate version in http and ws adapters (#2825)
Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
+17
-13
@@ -43,14 +43,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "http_adapter"
|
||||
envPrefix = "SMQ_HTTP_ADAPTER_"
|
||||
envPrefixClients = "SMQ_CLIENTS_GRPC_"
|
||||
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
|
||||
envPrefixAuth = "SMQ_AUTH_GRPC_"
|
||||
defSvcHTTPPort = "80"
|
||||
targetHTTPPort = "81"
|
||||
targetHTTPHost = "http://localhost"
|
||||
svcName = "http_adapter"
|
||||
envPrefix = "SMQ_HTTP_ADAPTER_"
|
||||
envPrefixClients = "SMQ_CLIENTS_GRPC_"
|
||||
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
|
||||
envPrefixAuth = "SMQ_AUTH_GRPC_"
|
||||
defSvcHTTPPort = "80"
|
||||
targetHTTPProtocol = "http"
|
||||
targetHTTPHost = "localhost"
|
||||
targetHTTPPort = "81"
|
||||
targetHTTPPath = ""
|
||||
)
|
||||
|
||||
type config struct {
|
||||
@@ -210,9 +212,11 @@ func newService(pub messaging.Publisher, authn smqauthn.Authentication, clients
|
||||
|
||||
func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sessionHandler session.Handler) error {
|
||||
config := mgate.Config{
|
||||
Address: fmt.Sprintf("%s:%s", "", cfg.Port),
|
||||
Target: fmt.Sprintf("%s:%s", targetHTTPHost, targetHTTPPort),
|
||||
PathPrefix: "/",
|
||||
Port: cfg.Port,
|
||||
TargetProtocol: targetHTTPProtocol,
|
||||
TargetHost: targetHTTPHost,
|
||||
TargetPort: targetHTTPPort,
|
||||
TargetPath: targetHTTPPath,
|
||||
}
|
||||
if cfg.CertFile != "" || cfg.KeyFile != "" {
|
||||
tlsCert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
@@ -223,7 +227,7 @@ func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sess
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
}
|
||||
}
|
||||
mp, err := mgatehttp.NewProxy(config, sessionHandler, logger)
|
||||
mp, err := mgatehttp.NewProxy(config, sessionHandler, logger, []string{}, []string{"/health", "/metrics"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -245,7 +249,7 @@ func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sess
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info(fmt.Sprintf("proxy HTTP shutdown at %s", config.Target))
|
||||
logger.Info(fmt.Sprintf("proxy HTTP shutdown at %s:%s", config.Host, config.Port))
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
|
||||
+15
-9
@@ -53,6 +53,7 @@ const (
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_MQTT_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
MQTTPort string `env:"SMQ_MQTT_ADAPTER_MQTT_PORT" envDefault:"1883"`
|
||||
MQTTTargetProtocol string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL" envDefault:"mqtt"`
|
||||
MQTTTargetHost string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST" envDefault:"localhost"`
|
||||
MQTTTargetPort string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT" envDefault:"1883"`
|
||||
MQTTTargetUsername string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME" envDefault:""`
|
||||
@@ -61,6 +62,7 @@ type config struct {
|
||||
MQTTTargetHealthCheck string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK" envDefault:""`
|
||||
MQTTQoS uint8 `env:"SMQ_MQTT_ADAPTER_MQTT_QOS" envDefault:"1"`
|
||||
HTTPPort string `env:"SMQ_MQTT_ADAPTER_WS_PORT" envDefault:"8080"`
|
||||
HTTPTargetProtocol string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL" envDefault:"http"`
|
||||
HTTPTargetHost string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_HOST" envDefault:"localhost"`
|
||||
HTTPTargetPort string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PORT" envDefault:"8080"`
|
||||
HTTPTargetPath string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PATH" envDefault:"/mqtt"`
|
||||
@@ -250,10 +252,11 @@ func main() {
|
||||
|
||||
func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHandler session.Handler, interceptor session.Interceptor) error {
|
||||
config := mgate.Config{
|
||||
Address: fmt.Sprintf(":%s", cfg.MQTTPort),
|
||||
Target: fmt.Sprintf("%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort),
|
||||
Port: cfg.MQTTPort,
|
||||
TargetHost: cfg.MQTTTargetHost,
|
||||
TargetPort: cfg.MQTTTargetPort,
|
||||
}
|
||||
mproxy := mgatemqtt.New(config, sessionHandler, interceptor, logger)
|
||||
mproxy := mgatemqtt.New(config, sessionHandler, nil, interceptor, logger)
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
@@ -262,7 +265,7 @@ func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHand
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info(fmt.Sprintf("proxy MQTT shutdown at %s", config.Target))
|
||||
logger.Info(fmt.Sprintf("proxy MQTT shutdown at %s:%s", config.Host, config.Port))
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
@@ -271,12 +274,15 @@ func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHand
|
||||
|
||||
func proxyWS(ctx context.Context, cfg config, logger *slog.Logger, sessionHandler session.Handler, interceptor session.Interceptor) error {
|
||||
config := mgate.Config{
|
||||
Address: fmt.Sprintf("%s:%s", "", cfg.HTTPPort),
|
||||
Target: fmt.Sprintf("ws://%s:%s%s", cfg.HTTPTargetHost, cfg.HTTPTargetPort, cfg.HTTPTargetPath),
|
||||
PathPrefix: wsPathPrefix,
|
||||
Port: cfg.HTTPPort,
|
||||
TargetProtocol: "http",
|
||||
TargetHost: cfg.HTTPTargetHost,
|
||||
TargetPort: cfg.HTTPTargetPort,
|
||||
TargetPath: cfg.HTTPTargetPath,
|
||||
PathPrefix: wsPathPrefix,
|
||||
}
|
||||
|
||||
wp := websocket.New(config, sessionHandler, interceptor, logger)
|
||||
wp := websocket.New(config, sessionHandler, nil, interceptor, logger)
|
||||
http.HandleFunc(wsPathPrefix, wp.ServeHTTP)
|
||||
|
||||
errCh := make(chan error)
|
||||
@@ -287,7 +293,7 @@ func proxyWS(ctx context.Context, cfg config, logger *slog.Logger, sessionHandle
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s", config.Target))
|
||||
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s:%s", config.Host, config.Port))
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
|
||||
+18
-16
@@ -13,8 +13,9 @@ import (
|
||||
"os"
|
||||
|
||||
chclient "github.com/absmach/callhome/pkg/client"
|
||||
"github.com/absmach/mgate"
|
||||
"github.com/absmach/mgate/pkg/http"
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
"github.com/absmach/mgate/pkg/websockets"
|
||||
"github.com/absmach/supermq"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
@@ -45,8 +46,9 @@ const (
|
||||
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
|
||||
envPrefixAuth = "SMQ_AUTH_GRPC_"
|
||||
defSvcHTTPPort = "8190"
|
||||
targetWSPort = "8191"
|
||||
targetWSProtocol = "http"
|
||||
targetWSHost = "localhost"
|
||||
targetWSPort = "8191"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
@@ -184,9 +186,10 @@ func main() {
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
g.Go(func() error {
|
||||
return hs.Start()
|
||||
})
|
||||
return hs.Start()
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient)
|
||||
return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler)
|
||||
})
|
||||
@@ -210,9 +213,14 @@ func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcC
|
||||
}
|
||||
|
||||
func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger *slog.Logger, handler session.Handler) error {
|
||||
target := fmt.Sprintf("ws://%s:%s", targetConfig.Host, targetConfig.Port)
|
||||
address := fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port)
|
||||
wp, err := websockets.NewProxy(address, target, logger, handler)
|
||||
config := mgate.Config{
|
||||
Host: hostConfig.Host,
|
||||
Port: hostConfig.Port,
|
||||
TargetProtocol: targetWSProtocol,
|
||||
TargetHost: targetWSHost,
|
||||
TargetPort: targetWSPort,
|
||||
}
|
||||
wp, err := http.NewProxy(config, handler, logger, []string{}, []string{"/health", "/metrics"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -220,18 +228,12 @@ func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
if hostConfig.CertFile != "" && hostConfig.KeyFile != "" {
|
||||
logger.Info(fmt.Sprintf("ws-adapter service HTTP server listening at %s:%s with TLS", hostConfig.Host, hostConfig.Port))
|
||||
errCh <- wp.ListenTLS(hostConfig.CertFile, hostConfig.KeyFile)
|
||||
} else {
|
||||
logger.Info(fmt.Sprintf("ws-adapter service HTTP server listening at %s:%s without TLS", hostConfig.Host, hostConfig.Port))
|
||||
errCh <- wp.Listen()
|
||||
}
|
||||
errCh <- wp.Listen(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s", target))
|
||||
logger.Info(fmt.Sprintf("ws-adapter service shutdown at %s:%s", hostConfig.Host, hostConfig.Port))
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
|
||||
@@ -42,11 +42,13 @@ SMQ_MESSAGE_BROKER_URL=${SMQ_NATS_URL}
|
||||
SMQ_MQTT_BROKER_TYPE=rabbitmq
|
||||
SMQ_MQTT_BROKER_HEALTH_CHECK=
|
||||
SMQ_MQTT_ADAPTER_MQTT_QOS=${SMQ_RABBITMQ_MQTT_QOS}
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL=mqtt
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE}
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT=1883
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME=${SMQ_RABBITMQ_USER}
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD=${SMQ_RABBITMQ_PASS}
|
||||
SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK=${SMQ_MQTT_BROKER_HEALTH_CHECK}
|
||||
SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL=http
|
||||
SMQ_MQTT_ADAPTER_WS_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE}
|
||||
SMQ_MQTT_ADAPTER_WS_TARGET_PORT=${SMQ_RABBITMQ_WS_PORT}
|
||||
SMQ_MQTT_ADAPTER_WS_TARGET_PATH=${SMQ_RABBITMQ_WS_TARGET_PATH}
|
||||
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/0x6flab/namegenerator v1.4.0
|
||||
github.com/absmach/callhome v0.14.0
|
||||
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02
|
||||
github.com/absmach/mgate v0.4.5
|
||||
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921
|
||||
github.com/absmach/senml v1.0.7
|
||||
github.com/authzed/authzed-go v1.4.0
|
||||
github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8
|
||||
|
||||
@@ -21,8 +21,8 @@ github.com/absmach/callhome v0.14.0 h1:zB4tIZJ1YUmZ1VGHFPfMA/Lo6/Mv19y2dvoOiXj2B
|
||||
github.com/absmach/callhome v0.14.0/go.mod h1:l12UJOfibK4Muvg/AbupHuquNV9qSz/ROdTEPg7f2Vk=
|
||||
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 h1:0CGxkUgYSCCQftMjsWRGV4RxrGrPE+gjfm/sWSNXesY=
|
||||
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02/go.mod h1:nQ/FYuITyIGmM7LO9gzt7a9L1FCjxPoBXrc9oSuBEyo=
|
||||
github.com/absmach/mgate v0.4.5 h1:l6RmrEsR9jxkdb9WHUSecmT0HA41TkZZQVffFfUAIfI=
|
||||
github.com/absmach/mgate v0.4.5/go.mod h1:IvRIHZexZPEIAPmmaJF0L5DY2ERjj+GxRGitOW4s6qo=
|
||||
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921 h1:Y0M0jtSbKmfrwLWJaphfiuzg7boKcjVxOt4ViMZ5OV8=
|
||||
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921/go.mod h1:BYazn/DsEeZxJxWZxy/5NiaS/CfWpR/5auYmbq43VwQ=
|
||||
github.com/absmach/senml v1.0.7 h1:XLvpw0qxbP2QhOz7KLM2ZRar+vSCpSG/0o0kEvWx3No=
|
||||
github.com/absmach/senml v1.0.7/go.mod h1:3bRIiNc8hq7l3auMs8gQrpsM5hHy7iDuiLILrf/+MfA=
|
||||
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
|
||||
|
||||
@@ -6,8 +6,10 @@ package api_test
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -54,11 +56,18 @@ func newTargetHTTPServer() *httptest.Server {
|
||||
}
|
||||
|
||||
func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) {
|
||||
ptUrl, _ := url.Parse(targetServer.URL)
|
||||
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
|
||||
config := mgate.Config{
|
||||
Address: "",
|
||||
Target: targetServer.URL,
|
||||
Host: "",
|
||||
Port: "",
|
||||
PathPrefix: "",
|
||||
TargetHost: ptHost,
|
||||
TargetPort: ptPort,
|
||||
TargetProtocol: ptUrl.Scheme,
|
||||
TargetPath: ptUrl.Path,
|
||||
}
|
||||
mp, err := proxy.NewProxy(config, svc, smqlog.NewMock())
|
||||
mp, err := proxy.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -168,7 +177,7 @@ func TestPublish(t *testing.T) {
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: "",
|
||||
status: http.StatusBadGateway,
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
desc: "publish message with basic auth",
|
||||
|
||||
+13
-3
@@ -6,8 +6,10 @@ package sdk_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/mgate"
|
||||
@@ -44,11 +46,19 @@ func setupMessages() (*httptest.Server, *pubsub.PubSub) {
|
||||
mux := api.MakeHandler(smqlog.NewMock(), "")
|
||||
target := httptest.NewServer(mux)
|
||||
|
||||
ptUrl, _ := url.Parse(target.URL)
|
||||
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
|
||||
config := mgate.Config{
|
||||
Address: "",
|
||||
Target: target.URL,
|
||||
Host: "",
|
||||
Port: "",
|
||||
PathPrefix: "",
|
||||
TargetHost: ptHost,
|
||||
TargetPort: ptPort,
|
||||
TargetProtocol: ptUrl.Scheme,
|
||||
TargetPath: ptUrl.Path,
|
||||
}
|
||||
mp, err := proxy.NewProxy(config, handler, smqlog.NewMock())
|
||||
|
||||
mp, err := proxy.NewProxy(config, handler, smqlog.NewMock(), []string{}, []string{"/health", "/metrics"})
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
+18
-10
@@ -20,14 +20,10 @@ import (
|
||||
const chansPrefix = "channels"
|
||||
|
||||
var (
|
||||
// errFailedMessagePublish indicates that message publishing failed.
|
||||
errFailedMessagePublish = errors.New("failed to publish message")
|
||||
|
||||
// ErrFailedSubscription indicates that client couldn't subscribe to specified channel.
|
||||
ErrFailedSubscription = errors.New("failed to subscribe to a channel")
|
||||
|
||||
// errFailedUnsubscribe indicates that client couldn't unsubscribe from specified channel.
|
||||
errFailedUnsubscribe = errors.New("failed to unsubscribe from a channel")
|
||||
ErrFailedSubscribe = errors.New("failed to unsubscribe from topic")
|
||||
|
||||
// ErrEmptyTopic indicate absence of clientKey in the request.
|
||||
ErrEmptyTopic = errors.New("empty topic")
|
||||
@@ -39,7 +35,9 @@ type Service interface {
|
||||
// the channelID for subscription and domainID specifies the domain for authorization.
|
||||
// Subtopic is optional.
|
||||
// If the subscription is successful, nil is returned otherwise error is returned.
|
||||
Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *Client) error
|
||||
Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *Client) error
|
||||
|
||||
Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error
|
||||
}
|
||||
|
||||
var _ Service = (*adapterService)(nil)
|
||||
@@ -59,7 +57,7 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *Client) error {
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *Client) error {
|
||||
if chanID == "" || clientKey == "" || domainID == "" {
|
||||
return svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -69,15 +67,13 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, c
|
||||
return svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
c.id = clientID
|
||||
|
||||
subject := fmt.Sprintf("%s.%s", chansPrefix, chanID)
|
||||
if subtopic != "" {
|
||||
subject = fmt.Sprintf("%s.%s", subject, subtopic)
|
||||
}
|
||||
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: clientID,
|
||||
ID: sessionID,
|
||||
ClientID: clientID,
|
||||
Topic: subject,
|
||||
Handler: c,
|
||||
@@ -89,6 +85,18 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
|
||||
topic := fmt.Sprintf("%s.%s", chansPrefix, chanID)
|
||||
if subtopic != "" {
|
||||
topic = fmt.Sprintf("%s.%s", topic, subtopic)
|
||||
}
|
||||
|
||||
if err := svc.pubsub.Unsubscribe(ctx, sessionID, topic); err != nil {
|
||||
return errors.Wrap(ErrFailedSubscribe, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorize checks if the clientKey is authorized to access the channel
|
||||
// and returns the clientID if it is.
|
||||
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
|
||||
|
||||
+5
-3
@@ -6,6 +6,7 @@ package ws_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -45,6 +46,7 @@ var (
|
||||
Protocol: protocol,
|
||||
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
|
||||
}
|
||||
sessionID = "sessionID"
|
||||
)
|
||||
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
|
||||
@@ -58,7 +60,7 @@ func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *c
|
||||
func TestSubscribe(t *testing.T) {
|
||||
svc, pubsub, clients, channels := newService()
|
||||
|
||||
c := ws.NewClient(nil)
|
||||
c := ws.NewClient(slog.Default(), nil, sessionID)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -182,7 +184,7 @@ func TestSubscribe(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
subConfig := messaging.SubscriberConfig{
|
||||
ID: clientID,
|
||||
ID: sessionID,
|
||||
Topic: "channels." + tc.chanID + "." + subTopic,
|
||||
ClientID: clientID,
|
||||
Handler: c,
|
||||
@@ -200,7 +202,7 @@ func TestSubscribe(t *testing.T) {
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
repocall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
|
||||
err := svc.Subscribe(context.Background(), tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repocall.Unset()
|
||||
clientsCall.Unset()
|
||||
|
||||
+22
-5
@@ -6,14 +6,16 @@ package api_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/mgate"
|
||||
mHttp "github.com/absmach/mgate/pkg/http"
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
"github.com/absmach/mgate/pkg/websockets"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
@@ -55,11 +57,22 @@ func newHTTPServer(svc ws.Service) *httptest.Server {
|
||||
|
||||
func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) {
|
||||
turl := strings.ReplaceAll(targetServer.URL, "http", "ws")
|
||||
mp, err := websockets.NewProxy("", turl, smqlog.NewMock(), svc)
|
||||
ptUrl, _ := url.Parse(turl)
|
||||
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
|
||||
config := mgate.Config{
|
||||
Host: "",
|
||||
Port: "",
|
||||
PathPrefix: "",
|
||||
TargetHost: ptHost,
|
||||
TargetPort: ptPort,
|
||||
TargetProtocol: ptUrl.Scheme,
|
||||
TargetPath: ptUrl.Path,
|
||||
}
|
||||
mp, err := mHttp.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return httptest.NewServer(http.HandlerFunc(mp.Handler)), nil
|
||||
return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil
|
||||
}
|
||||
|
||||
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
|
||||
@@ -108,8 +121,12 @@ func TestHandshake(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
defer ts.Close()
|
||||
pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil)
|
||||
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
|
||||
clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool {
|
||||
return req.ClientSecret == clientKey
|
||||
})).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
|
||||
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
|
||||
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
|
||||
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
|
||||
|
||||
@@ -191,7 +208,7 @@ func TestHandshake(t *testing.T) {
|
||||
subtopic: "",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusBadGateway,
|
||||
status: http.StatusBadRequest,
|
||||
msg: []byte{},
|
||||
},
|
||||
{
|
||||
|
||||
+35
-6
@@ -5,7 +5,10 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -16,25 +19,51 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
var channelPartRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
|
||||
var (
|
||||
channelPartRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
|
||||
|
||||
func handshake(ctx context.Context, svc ws.Service) http.HandlerFunc {
|
||||
errGenSessionID = errors.New("failed to generate session id")
|
||||
)
|
||||
|
||||
func generateSessionID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", errors.Wrap(errGenSessionID, err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
encodeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := generateSessionID()
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error()))
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
req.conn = conn
|
||||
client := ws.NewClient(conn)
|
||||
|
||||
if err := svc.Subscribe(ctx, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil {
|
||||
req.conn.Close()
|
||||
client := ws.NewClient(logger, conn, sessionID)
|
||||
|
||||
client.SetCloseHandler(func(code int, text string) error {
|
||||
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.chanID, req.subtopic)
|
||||
})
|
||||
|
||||
go client.Start(ctx)
|
||||
|
||||
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+25
-2
@@ -25,10 +25,11 @@ func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service {
|
||||
|
||||
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("session_id", sessionID),
|
||||
slog.String("channel_id", chanID),
|
||||
slog.String("domain_id", domainID),
|
||||
}
|
||||
@@ -43,5 +44,27 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, domainID,
|
||||
lm.logger.Info("Subscribe completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
|
||||
return lm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("session_id", sessionID),
|
||||
slog.String("channel_id", chanID),
|
||||
slog.String("domain_id", domainID),
|
||||
}
|
||||
if subtopic != "" {
|
||||
args = append(args, "subtopic", subtopic)
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Unsubscribe failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Unsubscribe completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
|
||||
}
|
||||
|
||||
+11
-2
@@ -31,11 +31,20 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.
|
||||
}
|
||||
|
||||
// Subscribe instruments Subscribe method with metrics.
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "subscribe").Add(1)
|
||||
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
|
||||
return mm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "unsubscribe").Add(1)
|
||||
mm.latency.With("method", "unsubscribe").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,9 @@
|
||||
|
||||
package api
|
||||
|
||||
import "github.com/gorilla/websocket"
|
||||
|
||||
type connReq struct {
|
||||
clientKey string
|
||||
chanID string
|
||||
domainID string
|
||||
subtopic string
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
+2
-2
@@ -40,8 +40,8 @@ func MakeHandler(ctx context.Context, svc ws.Service, l *slog.Logger, instanceID
|
||||
logger = l
|
||||
|
||||
mux := chi.NewRouter()
|
||||
mux.Get("/m/{domainID}/c/{chanID}", handshake(ctx, svc))
|
||||
mux.Get("/m/{domainID}/c/{chanID}/*", handshake(ctx, svc))
|
||||
mux.Get("/m/{domainID}/c/{chanID}", handshake(ctx, svc, l))
|
||||
mux.Get("/m/{domainID}/c/{chanID}/*", handshake(ctx, svc, l))
|
||||
|
||||
mux.Get("/health", supermq.Health(service, instanceID))
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
+127
-10
@@ -4,22 +4,52 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var (
|
||||
errHandlerBlockedMsgChan = errors.New("message handler msg chan blocked (full)")
|
||||
errHandlerClosedMsgChan = errors.New("message handler closed msg chan")
|
||||
errFailedToWriteMsg = errors.New("failed to write message to connection")
|
||||
errFailedToWritePing = errors.New("failed to write ping to connection")
|
||||
errReadMsg = errors.New("failed to read messages ")
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = 30 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
)
|
||||
|
||||
// Client handles messaging and websocket connection.
|
||||
type Client struct {
|
||||
conn *websocket.Conn
|
||||
id string
|
||||
logger *slog.Logger
|
||||
conn *websocket.Conn
|
||||
id string
|
||||
msg chan *messaging.Message
|
||||
}
|
||||
|
||||
// NewClient returns a new websocket client.
|
||||
func NewClient(c *websocket.Conn) *Client {
|
||||
return &Client{
|
||||
conn: c,
|
||||
id: "",
|
||||
func NewClient(logger *slog.Logger, conn *websocket.Conn, sessionID string) *Client {
|
||||
c := &Client{
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
id: sessionID,
|
||||
msg: make(chan *messaging.Message, 1024),
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Cancel handles the websocket connection after unsubscribing.
|
||||
@@ -32,10 +62,97 @@ func (c *Client) Cancel() error {
|
||||
|
||||
// Handle handles the sending and receiving of messages via the broker.
|
||||
func (c *Client) Handle(msg *messaging.Message) error {
|
||||
// To prevent publisher from receiving its own published message
|
||||
if msg.GetPublisher() == c.id {
|
||||
select {
|
||||
case c.msg <- msg:
|
||||
return nil
|
||||
default:
|
||||
return errHandlerBlockedMsgChan
|
||||
}
|
||||
}
|
||||
|
||||
// CloseHandler will work only if messages are read.
|
||||
func (c *Client) readPump(ctx context.Context, cancel context.CancelFunc) error {
|
||||
defer cancel()
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logger.Debug("read_pump: received context Done")
|
||||
return nil
|
||||
default:
|
||||
msgType, msg, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
c.logger.Debug("read_pump: unexpected close error", slog.String("error", err.Error()))
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(errReadMsg, err)
|
||||
}
|
||||
c.logger.Debug("read_pump: received message ", slog.Int("message_type", msgType), slog.String("message", string(msg)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) writePump(ctx context.Context, cancel context.CancelFunc) error {
|
||||
defer cancel()
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logger.Debug("write_pump: received context Done ")
|
||||
return nil
|
||||
case msg, ok := <-c.msg:
|
||||
if !ok {
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
|
||||
return errors.Wrap(errHandlerClosedMsgChan, err)
|
||||
}
|
||||
return errHandlerClosedMsgChan
|
||||
}
|
||||
if err := c.conn.WriteMessage(websocket.BinaryMessage, msg.GetPayload()); err != nil {
|
||||
return errors.Wrap(errFailedToWriteMsg, err)
|
||||
}
|
||||
case <-ticker.C:
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return errors.Wrap(errFailedToWritePing, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetCloseHandler sets a close handler for the WebSocket connection.
|
||||
func (c *Client) SetCloseHandler(handler func(code int, text string) error) {
|
||||
c.conn.SetCloseHandler(func(code int, text string) error {
|
||||
c.logger.Debug("WebSocket closed", slog.String("session_id", c.id), slog.Int("code", code), slog.String("text", text))
|
||||
if err := handler(code, text); err != nil {
|
||||
c.logger.Warn("Error in close handler", slog.String("error", err.Error()))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) Start(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
return c.readPump(ctx, cancel)
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return c.writePump(ctx, cancel)
|
||||
})
|
||||
|
||||
err := g.Wait()
|
||||
if err != nil {
|
||||
c.logger.Warn("websocket client error", slog.String("session_id", c.id), slog.String("error", err.Error()))
|
||||
}
|
||||
|
||||
return c.conn.WriteMessage(websocket.TextMessage, msg.GetPayload())
|
||||
}
|
||||
|
||||
+5
-2
@@ -4,7 +4,9 @@
|
||||
package ws_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -17,7 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const expectedCount = uint64(1)
|
||||
const expectedCount = uint64(2)
|
||||
|
||||
var (
|
||||
msgChan = make(chan []byte)
|
||||
@@ -61,7 +63,8 @@ func TestHandle(t *testing.T) {
|
||||
}
|
||||
defer wsConn.Close()
|
||||
|
||||
c = ws.NewClient(wsConn)
|
||||
c = ws.NewClient(slog.Default(), wsConn, "sessionID")
|
||||
go c.Start(context.Background())
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
|
||||
+26
-33
@@ -7,11 +7,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mgate "github.com/absmach/mgate/pkg/http"
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
@@ -31,7 +33,6 @@ const protocol = "websocket"
|
||||
// Log message formats.
|
||||
const (
|
||||
LogInfoSubscribed = "subscribed with client_id %s to topics %s"
|
||||
LogInfoUnsubscribed = "unsubscribed client_id %s from topics %s"
|
||||
LogInfoConnected = "connected with client_id %s"
|
||||
LogInfoDisconnected = "disconnected client_id %s and username %s"
|
||||
LogInfoPublished = "published with client_id %s to the topic %s"
|
||||
@@ -39,19 +40,16 @@ const (
|
||||
|
||||
// Error wrappers for MQTT errors.
|
||||
var (
|
||||
errMalformedSubtopic = errors.New("malformed subtopic")
|
||||
errClientNotInitialized = errors.New("client is not initialized")
|
||||
errMalformedTopic = errors.New("malformed topic")
|
||||
errMissingTopicPub = errors.New("failed to publish due to missing topic")
|
||||
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
|
||||
errFailedSubscribe = errors.New("failed to subscribe")
|
||||
errFailedPublish = errors.New("failed to publish")
|
||||
errFailedParseSubtopic = errors.New("failed to parse subtopic")
|
||||
channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
|
||||
|
||||
errMalformedSubtopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed subtopic"))
|
||||
errClientNotInitialized = mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.New("client is not initialized"))
|
||||
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
|
||||
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
|
||||
errMissingTopicSub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to subscribe due to missing topic"))
|
||||
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
|
||||
)
|
||||
|
||||
var channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
|
||||
|
||||
// Event implements events.Event interface.
|
||||
type handler struct {
|
||||
pubsub messaging.PubSub
|
||||
@@ -131,18 +129,19 @@ func (h *handler) Connect(ctx context.Context) error {
|
||||
func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error {
|
||||
s, ok := session.FromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Wrap(errFailedPublish, errClientNotInitialized)
|
||||
return errClientNotInitialized
|
||||
}
|
||||
|
||||
if len(*payload) == 0 {
|
||||
return errFailedMessagePublish
|
||||
h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Topics are in the format:
|
||||
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
|
||||
channelParts := channelRegExp.FindStringSubmatch(*topic)
|
||||
if len(channelParts) < 3 {
|
||||
return errors.Wrap(errFailedPublish, errMalformedTopic)
|
||||
return errMalformedTopic
|
||||
}
|
||||
|
||||
domainID := channelParts[1]
|
||||
@@ -151,12 +150,12 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
|
||||
subtopic, err := parseSubtopic(subtopic)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedParseSubtopic, err)
|
||||
return err
|
||||
}
|
||||
|
||||
clientID, clientType, err := h.authAccess(ctx, string(s.Password), *topic, connections.Publish)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedPublish, err)
|
||||
return err
|
||||
}
|
||||
|
||||
msg := messaging.Message{
|
||||
@@ -173,7 +172,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
}
|
||||
|
||||
if err := h.pubsub.Publish(ctx, msg.GetChannel(), &msg); err != nil {
|
||||
return errors.Wrap(errFailedPublishToMsgBroker, err)
|
||||
return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err))
|
||||
}
|
||||
|
||||
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
|
||||
@@ -185,7 +184,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
|
||||
s, ok := session.FromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Wrap(errFailedSubscribe, errClientNotInitialized)
|
||||
return errClientNotInitialized
|
||||
}
|
||||
h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ",")))
|
||||
return nil
|
||||
@@ -193,12 +192,6 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
|
||||
|
||||
// Unsubscribe - after client unsubscribed.
|
||||
func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
|
||||
s, ok := session.FromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Wrap(errFailedUnsubscribe, errClientNotInitialized)
|
||||
}
|
||||
|
||||
h.logger.Info(fmt.Sprintf(LogInfoUnsubscribed, s.ID, strings.Join(*topics, ",")))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -207,7 +200,7 @@ func (h *handler) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, error) {
|
||||
func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, mgate.HTTPProxyError) {
|
||||
authnReq := &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: token,
|
||||
}
|
||||
@@ -217,23 +210,23 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
|
||||
|
||||
authnRes, err := h.clients.Authenticate(ctx, authnReq)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
if !authnRes.GetAuthenticated() {
|
||||
return "", "", svcerr.ErrAuthentication
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
clientType := policies.ClientType
|
||||
clientID := authnRes.GetId()
|
||||
|
||||
// Topics are in the format:
|
||||
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
|
||||
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
|
||||
if !channelRegExp.MatchString(topic) {
|
||||
return "", "", errMalformedTopic
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusBadRequest, errMalformedTopic)
|
||||
}
|
||||
|
||||
channelParts := channelRegExp.FindStringSubmatch(topic)
|
||||
if len(channelParts) < 3 {
|
||||
return "", "", errMalformedTopic
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusBadRequest, errMalformedTopic)
|
||||
}
|
||||
|
||||
domainID := channelParts[1]
|
||||
@@ -248,16 +241,16 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
|
||||
}
|
||||
res, err := h.channels.Authorize(ctx, ar)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
if !res.GetAuthorized() {
|
||||
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
|
||||
return clientID, clientType, nil
|
||||
}
|
||||
|
||||
func parseSubtopic(subtopic string) (string, error) {
|
||||
func parseSubtopic(subtopic string) (string, mgate.HTTPProxyError) {
|
||||
if subtopic == "" {
|
||||
return subtopic, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
var _ ws.Service = (*tracingMiddleware)(nil)
|
||||
|
||||
const (
|
||||
publishOP = "publish_op"
|
||||
subscribeOP = "subscribe_op"
|
||||
unsubscribeOP = "unsubscribe_op"
|
||||
)
|
||||
@@ -32,9 +31,16 @@ func New(tracer trace.Tracer, svc ws.Service) ws.Service {
|
||||
}
|
||||
|
||||
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
|
||||
func (tm *tracingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *ws.Client) error {
|
||||
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *ws.Client) error {
|
||||
ctx, span := tm.tracer.Start(ctx, subscribeOP)
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, client)
|
||||
return tm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, client)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, unsubscribeOP)
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user