NOISSUE - Update mGate version in http and ws adapters (#2825)

Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Arvindh
2025-04-25 17:36:37 +05:30
committed by GitHub
parent cc9fc70d6d
commit 5ea3cbd014
20 changed files with 366 additions and 129 deletions
+17 -13
View File
@@ -43,14 +43,16 @@ import (
)
const (
svcName = "http_adapter"
envPrefix = "SMQ_HTTP_ADAPTER_"
envPrefixClients = "SMQ_CLIENTS_GRPC_"
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
defSvcHTTPPort = "80"
targetHTTPPort = "81"
targetHTTPHost = "http://localhost"
svcName = "http_adapter"
envPrefix = "SMQ_HTTP_ADAPTER_"
envPrefixClients = "SMQ_CLIENTS_GRPC_"
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
defSvcHTTPPort = "80"
targetHTTPProtocol = "http"
targetHTTPHost = "localhost"
targetHTTPPort = "81"
targetHTTPPath = ""
)
type config struct {
@@ -210,9 +212,11 @@ func newService(pub messaging.Publisher, authn smqauthn.Authentication, clients
func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sessionHandler session.Handler) error {
config := mgate.Config{
Address: fmt.Sprintf("%s:%s", "", cfg.Port),
Target: fmt.Sprintf("%s:%s", targetHTTPHost, targetHTTPPort),
PathPrefix: "/",
Port: cfg.Port,
TargetProtocol: targetHTTPProtocol,
TargetHost: targetHTTPHost,
TargetPort: targetHTTPPort,
TargetPath: targetHTTPPath,
}
if cfg.CertFile != "" || cfg.KeyFile != "" {
tlsCert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
@@ -223,7 +227,7 @@ func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sess
Certificates: []tls.Certificate{tlsCert},
}
}
mp, err := mgatehttp.NewProxy(config, sessionHandler, logger)
mp, err := mgatehttp.NewProxy(config, sessionHandler, logger, []string{}, []string{"/health", "/metrics"})
if err != nil {
return err
}
@@ -245,7 +249,7 @@ func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sess
select {
case <-ctx.Done():
logger.Info(fmt.Sprintf("proxy HTTP shutdown at %s", config.Target))
logger.Info(fmt.Sprintf("proxy HTTP shutdown at %s:%s", config.Host, config.Port))
return nil
case err := <-errCh:
return err
+15 -9
View File
@@ -53,6 +53,7 @@ const (
type config struct {
LogLevel string `env:"SMQ_MQTT_ADAPTER_LOG_LEVEL" envDefault:"info"`
MQTTPort string `env:"SMQ_MQTT_ADAPTER_MQTT_PORT" envDefault:"1883"`
MQTTTargetProtocol string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL" envDefault:"mqtt"`
MQTTTargetHost string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST" envDefault:"localhost"`
MQTTTargetPort string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT" envDefault:"1883"`
MQTTTargetUsername string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME" envDefault:""`
@@ -61,6 +62,7 @@ type config struct {
MQTTTargetHealthCheck string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK" envDefault:""`
MQTTQoS uint8 `env:"SMQ_MQTT_ADAPTER_MQTT_QOS" envDefault:"1"`
HTTPPort string `env:"SMQ_MQTT_ADAPTER_WS_PORT" envDefault:"8080"`
HTTPTargetProtocol string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL" envDefault:"http"`
HTTPTargetHost string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_HOST" envDefault:"localhost"`
HTTPTargetPort string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PORT" envDefault:"8080"`
HTTPTargetPath string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PATH" envDefault:"/mqtt"`
@@ -250,10 +252,11 @@ func main() {
func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHandler session.Handler, interceptor session.Interceptor) error {
config := mgate.Config{
Address: fmt.Sprintf(":%s", cfg.MQTTPort),
Target: fmt.Sprintf("%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort),
Port: cfg.MQTTPort,
TargetHost: cfg.MQTTTargetHost,
TargetPort: cfg.MQTTTargetPort,
}
mproxy := mgatemqtt.New(config, sessionHandler, interceptor, logger)
mproxy := mgatemqtt.New(config, sessionHandler, nil, interceptor, logger)
errCh := make(chan error)
go func() {
@@ -262,7 +265,7 @@ func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHand
select {
case <-ctx.Done():
logger.Info(fmt.Sprintf("proxy MQTT shutdown at %s", config.Target))
logger.Info(fmt.Sprintf("proxy MQTT shutdown at %s:%s", config.Host, config.Port))
return nil
case err := <-errCh:
return err
@@ -271,12 +274,15 @@ func proxyMQTT(ctx context.Context, cfg config, logger *slog.Logger, sessionHand
func proxyWS(ctx context.Context, cfg config, logger *slog.Logger, sessionHandler session.Handler, interceptor session.Interceptor) error {
config := mgate.Config{
Address: fmt.Sprintf("%s:%s", "", cfg.HTTPPort),
Target: fmt.Sprintf("ws://%s:%s%s", cfg.HTTPTargetHost, cfg.HTTPTargetPort, cfg.HTTPTargetPath),
PathPrefix: wsPathPrefix,
Port: cfg.HTTPPort,
TargetProtocol: "http",
TargetHost: cfg.HTTPTargetHost,
TargetPort: cfg.HTTPTargetPort,
TargetPath: cfg.HTTPTargetPath,
PathPrefix: wsPathPrefix,
}
wp := websocket.New(config, sessionHandler, interceptor, logger)
wp := websocket.New(config, sessionHandler, nil, interceptor, logger)
http.HandleFunc(wsPathPrefix, wp.ServeHTTP)
errCh := make(chan error)
@@ -287,7 +293,7 @@ func proxyWS(ctx context.Context, cfg config, logger *slog.Logger, sessionHandle
select {
case <-ctx.Done():
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s", config.Target))
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s:%s", config.Host, config.Port))
return nil
case err := <-errCh:
return err
+18 -16
View File
@@ -13,8 +13,9 @@ import (
"os"
chclient "github.com/absmach/callhome/pkg/client"
"github.com/absmach/mgate"
"github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
"github.com/absmach/mgate/pkg/websockets"
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
@@ -45,8 +46,9 @@ const (
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
defSvcHTTPPort = "8190"
targetWSPort = "8191"
targetWSProtocol = "http"
targetWSHost = "localhost"
targetWSPort = "8191"
)
type config struct {
@@ -184,9 +186,10 @@ func main() {
}
g.Go(func() error {
g.Go(func() error {
return hs.Start()
})
return hs.Start()
})
g.Go(func() error {
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient)
return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler)
})
@@ -210,9 +213,14 @@ func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcC
}
func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger *slog.Logger, handler session.Handler) error {
target := fmt.Sprintf("ws://%s:%s", targetConfig.Host, targetConfig.Port)
address := fmt.Sprintf("%s:%s", hostConfig.Host, hostConfig.Port)
wp, err := websockets.NewProxy(address, target, logger, handler)
config := mgate.Config{
Host: hostConfig.Host,
Port: hostConfig.Port,
TargetProtocol: targetWSProtocol,
TargetHost: targetWSHost,
TargetPort: targetWSPort,
}
wp, err := http.NewProxy(config, handler, logger, []string{}, []string{"/health", "/metrics"})
if err != nil {
return err
}
@@ -220,18 +228,12 @@ func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger
errCh := make(chan error)
go func() {
if hostConfig.CertFile != "" && hostConfig.KeyFile != "" {
logger.Info(fmt.Sprintf("ws-adapter service HTTP server listening at %s:%s with TLS", hostConfig.Host, hostConfig.Port))
errCh <- wp.ListenTLS(hostConfig.CertFile, hostConfig.KeyFile)
} else {
logger.Info(fmt.Sprintf("ws-adapter service HTTP server listening at %s:%s without TLS", hostConfig.Host, hostConfig.Port))
errCh <- wp.Listen()
}
errCh <- wp.Listen(ctx)
}()
select {
case <-ctx.Done():
logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s", target))
logger.Info(fmt.Sprintf("ws-adapter service shutdown at %s:%s", hostConfig.Host, hostConfig.Port))
return nil
case err := <-errCh:
return err
+2
View File
@@ -42,11 +42,13 @@ SMQ_MESSAGE_BROKER_URL=${SMQ_NATS_URL}
SMQ_MQTT_BROKER_TYPE=rabbitmq
SMQ_MQTT_BROKER_HEALTH_CHECK=
SMQ_MQTT_ADAPTER_MQTT_QOS=${SMQ_RABBITMQ_MQTT_QOS}
SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL=mqtt
SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE}
SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT=1883
SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME=${SMQ_RABBITMQ_USER}
SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD=${SMQ_RABBITMQ_PASS}
SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK=${SMQ_MQTT_BROKER_HEALTH_CHECK}
SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL=http
SMQ_MQTT_ADAPTER_WS_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE}
SMQ_MQTT_ADAPTER_WS_TARGET_PORT=${SMQ_RABBITMQ_WS_PORT}
SMQ_MQTT_ADAPTER_WS_TARGET_PATH=${SMQ_RABBITMQ_WS_TARGET_PATH}
+1 -1
View File
@@ -6,7 +6,7 @@ require (
github.com/0x6flab/namegenerator v1.4.0
github.com/absmach/callhome v0.14.0
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02
github.com/absmach/mgate v0.4.5
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921
github.com/absmach/senml v1.0.7
github.com/authzed/authzed-go v1.4.0
github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8
+2 -2
View File
@@ -21,8 +21,8 @@ github.com/absmach/callhome v0.14.0 h1:zB4tIZJ1YUmZ1VGHFPfMA/Lo6/Mv19y2dvoOiXj2B
github.com/absmach/callhome v0.14.0/go.mod h1:l12UJOfibK4Muvg/AbupHuquNV9qSz/ROdTEPg7f2Vk=
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 h1:0CGxkUgYSCCQftMjsWRGV4RxrGrPE+gjfm/sWSNXesY=
github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02/go.mod h1:nQ/FYuITyIGmM7LO9gzt7a9L1FCjxPoBXrc9oSuBEyo=
github.com/absmach/mgate v0.4.5 h1:l6RmrEsR9jxkdb9WHUSecmT0HA41TkZZQVffFfUAIfI=
github.com/absmach/mgate v0.4.5/go.mod h1:IvRIHZexZPEIAPmmaJF0L5DY2ERjj+GxRGitOW4s6qo=
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921 h1:Y0M0jtSbKmfrwLWJaphfiuzg7boKcjVxOt4ViMZ5OV8=
github.com/absmach/mgate v0.4.6-0.20250425104654-79c62d581921/go.mod h1:BYazn/DsEeZxJxWZxy/5NiaS/CfWpR/5auYmbq43VwQ=
github.com/absmach/senml v1.0.7 h1:XLvpw0qxbP2QhOz7KLM2ZRar+vSCpSG/0o0kEvWx3No=
github.com/absmach/senml v1.0.7/go.mod h1:3bRIiNc8hq7l3auMs8gQrpsM5hHy7iDuiLILrf/+MfA=
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
+13 -4
View File
@@ -6,8 +6,10 @@ package api_test
import (
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
@@ -54,11 +56,18 @@ func newTargetHTTPServer() *httptest.Server {
}
func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) {
ptUrl, _ := url.Parse(targetServer.URL)
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
config := mgate.Config{
Address: "",
Target: targetServer.URL,
Host: "",
Port: "",
PathPrefix: "",
TargetHost: ptHost,
TargetPort: ptPort,
TargetProtocol: ptUrl.Scheme,
TargetPath: ptUrl.Path,
}
mp, err := proxy.NewProxy(config, svc, smqlog.NewMock())
mp, err := proxy.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{})
if err != nil {
return nil, err
}
@@ -168,7 +177,7 @@ func TestPublish(t *testing.T) {
msg: msg,
contentType: ctSenmlJSON,
key: "",
status: http.StatusBadGateway,
status: http.StatusBadRequest,
},
{
desc: "publish message with basic auth",
+13 -3
View File
@@ -6,8 +6,10 @@ package sdk_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/absmach/mgate"
@@ -44,11 +46,19 @@ func setupMessages() (*httptest.Server, *pubsub.PubSub) {
mux := api.MakeHandler(smqlog.NewMock(), "")
target := httptest.NewServer(mux)
ptUrl, _ := url.Parse(target.URL)
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
config := mgate.Config{
Address: "",
Target: target.URL,
Host: "",
Port: "",
PathPrefix: "",
TargetHost: ptHost,
TargetPort: ptPort,
TargetProtocol: ptUrl.Scheme,
TargetPath: ptUrl.Path,
}
mp, err := proxy.NewProxy(config, handler, smqlog.NewMock())
mp, err := proxy.NewProxy(config, handler, smqlog.NewMock(), []string{}, []string{"/health", "/metrics"})
if err != nil {
return nil, nil
}
+18 -10
View File
@@ -20,14 +20,10 @@ import (
const chansPrefix = "channels"
var (
// errFailedMessagePublish indicates that message publishing failed.
errFailedMessagePublish = errors.New("failed to publish message")
// ErrFailedSubscription indicates that client couldn't subscribe to specified channel.
ErrFailedSubscription = errors.New("failed to subscribe to a channel")
// errFailedUnsubscribe indicates that client couldn't unsubscribe from specified channel.
errFailedUnsubscribe = errors.New("failed to unsubscribe from a channel")
ErrFailedSubscribe = errors.New("failed to unsubscribe from topic")
// ErrEmptyTopic indicate absence of clientKey in the request.
ErrEmptyTopic = errors.New("empty topic")
@@ -39,7 +35,9 @@ type Service interface {
// the channelID for subscription and domainID specifies the domain for authorization.
// Subtopic is optional.
// If the subscription is successful, nil is returned otherwise error is returned.
Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *Client) error
Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *Client) error
Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error
}
var _ Service = (*adapterService)(nil)
@@ -59,7 +57,7 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
}
}
func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *Client) error {
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *Client) error {
if chanID == "" || clientKey == "" || domainID == "" {
return svcerr.ErrAuthentication
}
@@ -69,15 +67,13 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, c
return svcerr.ErrAuthorization
}
c.id = clientID
subject := fmt.Sprintf("%s.%s", chansPrefix, chanID)
if subtopic != "" {
subject = fmt.Sprintf("%s.%s", subject, subtopic)
}
subCfg := messaging.SubscriberConfig{
ID: clientID,
ID: sessionID,
ClientID: clientID,
Topic: subject,
Handler: c,
@@ -89,6 +85,18 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, c
return nil
}
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
topic := fmt.Sprintf("%s.%s", chansPrefix, chanID)
if subtopic != "" {
topic = fmt.Sprintf("%s.%s", topic, subtopic)
}
if err := svc.pubsub.Unsubscribe(ctx, sessionID, topic); err != nil {
return errors.Wrap(ErrFailedSubscribe, err)
}
return nil
}
// authorize checks if the clientKey is authorized to access the channel
// and returns the clientID if it is.
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
+5 -3
View File
@@ -6,6 +6,7 @@ package ws_test
import (
"context"
"fmt"
"log/slog"
"strings"
"testing"
@@ -45,6 +46,7 @@ var (
Protocol: protocol,
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
}
sessionID = "sessionID"
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
@@ -58,7 +60,7 @@ func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *c
func TestSubscribe(t *testing.T) {
svc, pubsub, clients, channels := newService()
c := ws.NewClient(nil)
c := ws.NewClient(slog.Default(), nil, sessionID)
cases := []struct {
desc string
@@ -182,7 +184,7 @@ func TestSubscribe(t *testing.T) {
for _, tc := range cases {
subConfig := messaging.SubscriberConfig{
ID: clientID,
ID: sessionID,
Topic: "channels." + tc.chanID + "." + subTopic,
ClientID: clientID,
Handler: c,
@@ -200,7 +202,7 @@ func TestSubscribe(t *testing.T) {
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
repocall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
clientsCall.Unset()
+22 -5
View File
@@ -6,14 +6,16 @@ package api_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/absmach/mgate"
mHttp "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
"github.com/absmach/mgate/pkg/websockets"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
@@ -55,11 +57,22 @@ func newHTTPServer(svc ws.Service) *httptest.Server {
func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) {
turl := strings.ReplaceAll(targetServer.URL, "http", "ws")
mp, err := websockets.NewProxy("", turl, smqlog.NewMock(), svc)
ptUrl, _ := url.Parse(turl)
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
config := mgate.Config{
Host: "",
Port: "",
PathPrefix: "",
TargetHost: ptHost,
TargetPort: ptPort,
TargetProtocol: ptUrl.Scheme,
TargetPath: ptUrl.Path,
}
mp, err := mHttp.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{})
if err != nil {
return nil, err
}
return httptest.NewServer(http.HandlerFunc(mp.Handler)), nil
return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil
}
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
@@ -108,8 +121,12 @@ func TestHandshake(t *testing.T) {
require.Nil(t, err)
defer ts.Close()
pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil)
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool {
return req.ClientSecret == clientKey
})).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
@@ -191,7 +208,7 @@ func TestHandshake(t *testing.T) {
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusBadGateway,
status: http.StatusBadRequest,
msg: []byte{},
},
{
+35 -6
View File
@@ -5,7 +5,10 @@ package api
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"net/url"
"regexp"
@@ -16,25 +19,51 @@ import (
"github.com/go-chi/chi/v5"
)
var channelPartRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
var (
channelPartRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
func handshake(ctx context.Context, svc ws.Service) http.HandlerFunc {
errGenSessionID = errors.New("failed to generate session id")
)
func generateSessionID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", errors.Wrap(errGenSessionID, err)
}
return hex.EncodeToString(b), nil
}
func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r)
if err != nil {
encodeError(w, err)
return
}
sessionID, err := generateSessionID()
if err != nil {
logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error()))
http.Error(w, "", http.StatusInternalServerError)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
return
}
req.conn = conn
client := ws.NewClient(conn)
if err := svc.Subscribe(ctx, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil {
req.conn.Close()
client := ws.NewClient(logger, conn, sessionID)
client.SetCloseHandler(func(code int, text string) error {
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.chanID, req.subtopic)
})
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil {
conn.Close()
return
}
+25 -2
View File
@@ -25,10 +25,11 @@ func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service {
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("session_id", sessionID),
slog.String("channel_id", chanID),
slog.String("domain_id", domainID),
}
@@ -43,5 +44,27 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, domainID,
lm.logger.Info("Subscribe completed successfully", args...)
}(time.Now())
return lm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
return lm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
}
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("session_id", sessionID),
slog.String("channel_id", chanID),
slog.String("domain_id", domainID),
}
if subtopic != "" {
args = append(args, "subtopic", subtopic)
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Unsubscribe failed", args...)
return
}
lm.logger.Info("Unsubscribe completed successfully", args...)
}(time.Now())
return lm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
}
+11 -2
View File
@@ -31,11 +31,20 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
return mm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
}
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
defer func(begin time.Time) {
mm.counter.With("method", "unsubscribe").Add(1)
mm.latency.With("method", "unsubscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
}
-3
View File
@@ -3,12 +3,9 @@
package api
import "github.com/gorilla/websocket"
type connReq struct {
clientKey string
chanID string
domainID string
subtopic string
conn *websocket.Conn
}
+2 -2
View File
@@ -40,8 +40,8 @@ func MakeHandler(ctx context.Context, svc ws.Service, l *slog.Logger, instanceID
logger = l
mux := chi.NewRouter()
mux.Get("/m/{domainID}/c/{chanID}", handshake(ctx, svc))
mux.Get("/m/{domainID}/c/{chanID}/*", handshake(ctx, svc))
mux.Get("/m/{domainID}/c/{chanID}", handshake(ctx, svc, l))
mux.Get("/m/{domainID}/c/{chanID}/*", handshake(ctx, svc, l))
mux.Get("/health", supermq.Health(service, instanceID))
mux.Handle("/metrics", promhttp.Handler())
+127 -10
View File
@@ -4,22 +4,52 @@
package ws
import (
"context"
"log/slog"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
)
var (
errHandlerBlockedMsgChan = errors.New("message handler msg chan blocked (full)")
errHandlerClosedMsgChan = errors.New("message handler closed msg chan")
errFailedToWriteMsg = errors.New("failed to write message to connection")
errFailedToWritePing = errors.New("failed to write ping to connection")
errReadMsg = errors.New("failed to read messages ")
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = 30 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
)
// Client handles messaging and websocket connection.
type Client struct {
conn *websocket.Conn
id string
logger *slog.Logger
conn *websocket.Conn
id string
msg chan *messaging.Message
}
// NewClient returns a new websocket client.
func NewClient(c *websocket.Conn) *Client {
return &Client{
conn: c,
id: "",
func NewClient(logger *slog.Logger, conn *websocket.Conn, sessionID string) *Client {
c := &Client{
logger: logger,
conn: conn,
id: sessionID,
msg: make(chan *messaging.Message, 1024),
}
return c
}
// Cancel handles the websocket connection after unsubscribing.
@@ -32,10 +62,97 @@ func (c *Client) Cancel() error {
// Handle handles the sending and receiving of messages via the broker.
func (c *Client) Handle(msg *messaging.Message) error {
// To prevent publisher from receiving its own published message
if msg.GetPublisher() == c.id {
select {
case c.msg <- msg:
return nil
default:
return errHandlerBlockedMsgChan
}
}
// CloseHandler will work only if messages are read.
func (c *Client) readPump(ctx context.Context, cancel context.CancelFunc) error {
defer cancel()
c.conn.SetPongHandler(func(string) error {
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
return err
}
return nil
})
for {
select {
case <-ctx.Done():
c.logger.Debug("read_pump: received context Done")
return nil
default:
msgType, msg, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
c.logger.Debug("read_pump: unexpected close error", slog.String("error", err.Error()))
return nil
}
return errors.Wrap(errReadMsg, err)
}
c.logger.Debug("read_pump: received message ", slog.Int("message_type", msgType), slog.String("message", string(msg)))
}
}
}
func (c *Client) writePump(ctx context.Context, cancel context.CancelFunc) error {
defer cancel()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
return err
}
for {
select {
case <-ctx.Done():
c.logger.Debug("write_pump: received context Done ")
return nil
case msg, ok := <-c.msg:
if !ok {
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
return errors.Wrap(errHandlerClosedMsgChan, err)
}
return errHandlerClosedMsgChan
}
if err := c.conn.WriteMessage(websocket.BinaryMessage, msg.GetPayload()); err != nil {
return errors.Wrap(errFailedToWriteMsg, err)
}
case <-ticker.C:
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return errors.Wrap(errFailedToWritePing, err)
}
}
}
}
// SetCloseHandler sets a close handler for the WebSocket connection.
func (c *Client) SetCloseHandler(handler func(code int, text string) error) {
c.conn.SetCloseHandler(func(code int, text string) error {
c.logger.Debug("WebSocket closed", slog.String("session_id", c.id), slog.Int("code", code), slog.String("text", text))
if err := handler(code, text); err != nil {
c.logger.Warn("Error in close handler", slog.String("error", err.Error()))
}
return nil
})
}
func (c *Client) Start(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return c.readPump(ctx, cancel)
})
g.Go(func() error {
return c.writePump(ctx, cancel)
})
err := g.Wait()
if err != nil {
c.logger.Warn("websocket client error", slog.String("session_id", c.id), slog.String("error", err.Error()))
}
return c.conn.WriteMessage(websocket.TextMessage, msg.GetPayload())
}
+5 -2
View File
@@ -4,7 +4,9 @@
package ws_test
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
@@ -17,7 +19,7 @@ import (
"github.com/stretchr/testify/assert"
)
const expectedCount = uint64(1)
const expectedCount = uint64(2)
var (
msgChan = make(chan []byte)
@@ -61,7 +63,8 @@ func TestHandle(t *testing.T) {
}
defer wsConn.Close()
c = ws.NewClient(wsConn)
c = ws.NewClient(slog.Default(), wsConn, "sessionID")
go c.Start(context.Background())
cases := []struct {
desc string
+26 -33
View File
@@ -7,11 +7,13 @@ import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"regexp"
"strings"
"time"
mgate "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
@@ -31,7 +33,6 @@ const protocol = "websocket"
// Log message formats.
const (
LogInfoSubscribed = "subscribed with client_id %s to topics %s"
LogInfoUnsubscribed = "unsubscribed client_id %s from topics %s"
LogInfoConnected = "connected with client_id %s"
LogInfoDisconnected = "disconnected client_id %s and username %s"
LogInfoPublished = "published with client_id %s to the topic %s"
@@ -39,19 +40,16 @@ const (
// Error wrappers for MQTT errors.
var (
errMalformedSubtopic = errors.New("malformed subtopic")
errClientNotInitialized = errors.New("client is not initialized")
errMalformedTopic = errors.New("malformed topic")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
errFailedSubscribe = errors.New("failed to subscribe")
errFailedPublish = errors.New("failed to publish")
errFailedParseSubtopic = errors.New("failed to parse subtopic")
channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
errMalformedSubtopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed subtopic"))
errClientNotInitialized = mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.New("client is not initialized"))
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
errMissingTopicSub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to subscribe due to missing topic"))
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
)
var channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
// Event implements events.Event interface.
type handler struct {
pubsub messaging.PubSub
@@ -131,18 +129,19 @@ func (h *handler) Connect(ctx context.Context) error {
func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error {
s, ok := session.FromContext(ctx)
if !ok {
return errors.Wrap(errFailedPublish, errClientNotInitialized)
return errClientNotInitialized
}
if len(*payload) == 0 {
return errFailedMessagePublish
h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username))
return nil
}
// Topics are in the format:
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
channelParts := channelRegExp.FindStringSubmatch(*topic)
if len(channelParts) < 3 {
return errors.Wrap(errFailedPublish, errMalformedTopic)
return errMalformedTopic
}
domainID := channelParts[1]
@@ -151,12 +150,12 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
subtopic, err := parseSubtopic(subtopic)
if err != nil {
return errors.Wrap(errFailedParseSubtopic, err)
return err
}
clientID, clientType, err := h.authAccess(ctx, string(s.Password), *topic, connections.Publish)
if err != nil {
return errors.Wrap(errFailedPublish, err)
return err
}
msg := messaging.Message{
@@ -173,7 +172,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
}
if err := h.pubsub.Publish(ctx, msg.GetChannel(), &msg); err != nil {
return errors.Wrap(errFailedPublishToMsgBroker, err)
return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err))
}
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
@@ -185,7 +184,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errors.Wrap(errFailedSubscribe, errClientNotInitialized)
return errClientNotInitialized
}
h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ",")))
return nil
@@ -193,12 +192,6 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
// Unsubscribe - after client unsubscribed.
func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errors.Wrap(errFailedUnsubscribe, errClientNotInitialized)
}
h.logger.Info(fmt.Sprintf(LogInfoUnsubscribed, s.ID, strings.Join(*topics, ",")))
return nil
}
@@ -207,7 +200,7 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, error) {
func (h *handler) authAccess(ctx context.Context, token, topic string, msgType connections.ConnType) (string, string, mgate.HTTPProxyError) {
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: token,
}
@@ -217,23 +210,23 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
authnRes, err := h.clients.Authenticate(ctx, authnReq)
if err != nil {
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !authnRes.GetAuthenticated() {
return "", "", svcerr.ErrAuthentication
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
clientType := policies.ClientType
clientID := authnRes.GetId()
// Topics are in the format:
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
if !channelRegExp.MatchString(topic) {
return "", "", errMalformedTopic
return "", "", mgate.NewHTTPProxyError(http.StatusBadRequest, errMalformedTopic)
}
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 3 {
return "", "", errMalformedTopic
return "", "", mgate.NewHTTPProxyError(http.StatusBadRequest, errMalformedTopic)
}
domainID := channelParts[1]
@@ -248,16 +241,16 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !res.GetAuthorized() {
return "", "", errors.Wrap(svcerr.ErrAuthorization, err)
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
return clientID, clientType, nil
}
func parseSubtopic(subtopic string) (string, error) {
func parseSubtopic(subtopic string) (string, mgate.HTTPProxyError) {
if subtopic == "" {
return subtopic, nil
}
+9 -3
View File
@@ -13,7 +13,6 @@ import (
var _ ws.Service = (*tracingMiddleware)(nil)
const (
publishOP = "publish_op"
subscribeOP = "subscribe_op"
unsubscribeOP = "unsubscribe_op"
)
@@ -32,9 +31,16 @@ func New(tracer trace.Tracer, svc ws.Service) ws.Service {
}
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
func (tm *tracingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *ws.Client) error {
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *ws.Client) error {
ctx, span := tm.tracer.Start(ctx, subscribeOP)
defer span.End()
return tm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, client)
return tm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, client)
}
func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
ctx, span := tm.tracer.Start(ctx, unsubscribeOP)
defer span.End()
return tm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic)
}