SMQ-2800 - Add WebSocket support to HTTP adapter (#2937)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
Signed-off-by: Arvindh <arvindh91@gmail.com>
Co-authored-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Felix Gateru
2025-12-31 12:43:52 +03:00
committed by GitHub
parent a526a2ccd5
commit 67c28ff134
40 changed files with 1427 additions and 2663 deletions
+2 -2
View File
@@ -117,7 +117,7 @@ jobs:
run: make all -j $(nproc) && make dockers_dev -j $(nproc)
- name: Start containers
run: make run up args="-d" && make run_addons up args="-d"
run: make run_latest up args="-d" && make run_addons up args="-d"
- name: Wait for services to be ready
run: |
@@ -216,4 +216,4 @@ jobs:
- name: Stop containers
if: always()
run: make run down args="-v" && make run_addons down args="-v"
run: make run_latest down args="-v" && make run_addons down args="-v"
+1 -10
View File
@@ -240,14 +240,6 @@ jobs:
- "pkg/uuid/**"
- "pkg/events/**"
ws:
- "ws/**"
- "cmd/ws/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "clients/**"
- "pkg/messaging/**"
notifications:
- "notifications/**"
- "cmd/notifications/**"
@@ -272,7 +264,7 @@ jobs:
if [[ "${{ steps.changes.outputs.workflow }}" == "true" || "${{ steps.changes.outputs.pkg-errors }}" == "true" ]]; then
# If workflow or pkg/errors changed, test everything
modules=("auth" "channels" "cli" "clients" "coap" "domains" "groups" "http" "internal" "journal" "logger" "mqtt" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "ws" "notifications" "api" "consumers" "readers")
modules=("auth" "channels" "cli" "clients" "coap" "domains" "groups" "http" "internal" "journal" "logger" "mqtt" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers")
else
# Add only changed modules
[[ "${{ steps.changes.outputs.auth }}" == "true" ]] && modules+=("auth")
@@ -296,7 +288,6 @@ jobs:
[[ "${{ steps.changes.outputs.pkg-ulid }}" == "true" ]] && modules+=("pkg-ulid")
[[ "${{ steps.changes.outputs.pkg-uuid }}" == "true" ]] && modules+=("pkg-uuid")
[[ "${{ steps.changes.outputs.users }}" == "true" ]] && modules+=("users")
[[ "${{ steps.changes.outputs.ws }}" == "true" ]] && modules+=("ws")
[[ "${{ steps.changes.outputs.notifications }}" == "true" ]] && modules+=("notifications")
[[ "${{ steps.changes.outputs.api }}" == "true" ]] && modules+=("api")
[[ "${{ steps.changes.outputs.consumers }}" == "true" ]] && modules+=("consumers")
+1 -1
View File
@@ -3,7 +3,7 @@
SMQ_DOCKER_IMAGE_NAME_PREFIX ?= supermq
BUILD_DIR ?= build
SERVICES = auth users clients groups channels domains http coap ws cli mqtt journal notifications
SERVICES = auth users clients groups channels domains http coap cli mqtt journal notifications
TEST_API_SERVICES = journal auth certs http clients users channels groups domains
TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES))
DOCKERS = $(addprefix docker_,$(SERVICES))
+1 -1
View File
@@ -29,7 +29,7 @@ servers:
default: localhost
port:
description: SuperMQ WebSocket Adapter port
default: '8186'
default: '8008'
channels:
'm/{domainPrefix}/c/{channelPrefix}/{subtopic}':
+2
View File
@@ -47,6 +47,8 @@ paths:
description: Message discarded due to its malformed content.
"401":
description: Missing or invalid access token provided.
"403":
description: Access denied to the requested resource.
"404":
description: Message discarded due to invalid channel id.
"415":
+28 -14
View File
@@ -25,6 +25,7 @@ import (
"github.com/absmach/supermq/auth"
adapter "github.com/absmach/supermq/http"
httpapi "github.com/absmach/supermq/http/api"
"github.com/absmach/supermq/http/middleware"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
@@ -209,31 +210,34 @@ func main() {
}()
tracer := tp.Tracer(svcName)
pub, err := brokers.NewPublisher(ctx, cfg.BrokerURL)
nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger)
if err != nil {
logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err))
logger.Error(fmt.Sprintf("Failed to connect to message broker: %s", err))
exitCode = 1
return
}
defer pub.Close()
pub = brokerstracing.NewPublisher(httpServerConfig, tracer, pub)
defer nps.Close()
nps = brokerstracing.NewPubSub(httpServerConfig, tracer, nps)
pub, err = msgevents.NewPublisherMiddleware(ctx, pub, cfg.ESURL)
nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL)
if err != nil {
logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err))
exitCode = 1
return
}
svc, err := newService(pub, authn, cacheConfig, clientsClient, channelsClient, domainsClient, logger, tracer)
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
handler, err := newHandler(nps, authn, cacheConfig, clientsClient, channelsClient, domainsClient, logger, tracer)
if err != nil {
logger.Error(fmt.Sprintf("failed to create service: %s", err))
exitCode = 1
return
}
svc := newService(clientsClient, channelsClient, authn, nps, logger, tracer)
targetServerCfg := server.Config{Port: targetHTTPPort}
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(logger, cfg.InstanceID), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
@@ -245,7 +249,7 @@ func main() {
})
g.Go(func() error {
return proxyHTTP(ctx, httpServerConfig, logger, svc)
return proxyHTTP(ctx, httpServerConfig, logger, handler)
})
g.Go(func() error {
@@ -257,17 +261,27 @@ func main() {
}
}
func newService(pub messaging.Publisher, authn smqauthn.Authentication, cacheCfg messaging.CacheConfig, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) (session.Handler, error) {
func newHandler(pubsub messaging.PubSub, authn smqauthn.Authentication, cacheCfg messaging.CacheConfig, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) (session.Handler, error) {
parser, err := messaging.NewTopicParser(cacheCfg, channels, domains)
if err != nil {
return nil, err
}
svc := adapter.NewHandler(pub, authn, clients, channels, parser, logger)
svc = handler.NewTracing(tracer, svc)
svc = handler.NewLogging(svc, logger)
h := adapter.NewHandler(pubsub, logger, authn, clients, channels, parser)
h = handler.NewTracing(tracer, h)
h = handler.NewLogging(h, logger)
counter, latency := prometheus.MakeMetrics(svcName, "handler")
h = handler.NewMetrics(h, counter, latency)
return h, nil
}
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) adapter.Service {
svc := adapter.NewService(clientsClient, channels, authn, nps)
svc = middleware.NewTracing(tracer, svc)
svc = middleware.NewLogging(svc, logger)
counter, latency := prometheus.MakeMetrics(svcName, "api")
svc = handler.NewMetrics(svc, counter, latency)
return svc, nil
svc = middleware.NewMetrics(svc, counter, latency)
return svc
}
func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sessionHandler session.Handler) error {
-299
View File
@@ -1,299 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package main contains websocket-adapter main function to start the websocket-adapter service.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"net/url"
"os"
chclient "github.com/absmach/callhome/pkg/client"
"github.com/absmach/mgate"
"github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/auth"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/brokers"
brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing"
msgevents "github.com/absmach/supermq/pkg/messaging/events"
"github.com/absmach/supermq/pkg/prometheus"
"github.com/absmach/supermq/pkg/server"
httpserver "github.com/absmach/supermq/pkg/server/http"
"github.com/absmach/supermq/pkg/uuid"
"github.com/absmach/supermq/ws"
httpapi "github.com/absmach/supermq/ws/api"
"github.com/absmach/supermq/ws/middleware"
"github.com/caarlos0/env/v11"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)
const (
svcName = "ws-adapter"
envPrefixHTTP = "SMQ_WS_ADAPTER_HTTP_"
envPrefixCache = "SMQ_WS_ADAPTER_CACHE_"
envPrefixClients = "SMQ_CLIENTS_GRPC_"
envPrefixChannels = "SMQ_CHANNELS_GRPC_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
envPrefixDomains = "SMQ_DOMAINS_GRPC_"
defSvcHTTPPort = "8190"
targetWSProtocol = "http"
targetWSHost = "localhost"
targetWSPort = "8191"
)
type config struct {
LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"`
BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"`
JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"`
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
cfg := config{}
if err := env.Parse(&cfg); err != nil {
log.Fatalf("failed to load %s configuration : %s", svcName, err)
}
logger, err := smqlog.New(os.Stdout, cfg.LogLevel)
if err != nil {
log.Fatalf("failed to init logger: %s", err.Error())
}
var exitCode int
defer smqlog.ExitWithError(&exitCode)
if cfg.InstanceID == "" {
if cfg.InstanceID, err = uuid.New().ID(); err != nil {
logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err))
exitCode = 1
return
}
}
httpServerConfig := server.Config{Port: defSvcHTTPPort}
if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err))
exitCode = 1
return
}
targetServerConfig := server.Config{
Port: targetWSPort,
Host: targetWSHost,
}
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
exitCode = 1
return
}
_, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer domainsHandler.Close()
logger.Info("Domains service gRPC client successfully connected to domains gRPC server " + domainsHandler.Secure())
clientsClientCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err))
exitCode = 1
return
}
clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer clientsHandler.Close()
logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure())
channelsClientCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil {
logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err))
exitCode = 1
return
}
channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer channelsHandler.Close()
logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure())
authnCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&authnCfg, env.Options{Prefix: envPrefixAuth}); err != nil {
logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err))
exitCode = 1
return
}
isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm)
if err != nil {
logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err))
exitCode = 1
return
}
var authn authn.Authentication
var authnClient grpcclient.Handler
switch {
case !isSymmetric:
authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL)
default:
authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
}
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err))
exitCode = 1
return
}
defer func() {
if err := tp.Shutdown(ctx); err != nil {
logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err))
}
}()
tracer := tp.Tracer(svcName)
nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger)
if err != nil {
logger.Error(fmt.Sprintf("Failed to connect to message broker: %s", err))
exitCode = 1
return
}
defer nps.Close()
nps = brokerstracing.NewPubSub(targetServerConfig, tracer, nps)
nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL)
if err != nil {
logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err))
exitCode = 1
return
}
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
cacheConfig := messaging.CacheConfig{}
if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil {
logger.Error(fmt.Sprintf("failed to load cache configuration : %s", err))
exitCode = 1
return
}
parser, err := messaging.NewTopicParser(cacheConfig, channelsClient, domainsClient)
if err != nil {
logger.Error(fmt.Sprintf("failed to create topic parser: %s", err))
exitCode = 1
return
}
svc := newService(clientsClient, channelsClient, authn, nps, logger, tracer)
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
go chc.CallHome(ctx)
}
g.Go(func() error {
return hs.Start()
})
g.Go(func() error {
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, parser)
return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler)
})
g.Go(func() error {
return server.StopSignalHandler(ctx, cancel, logger, svcName, hs)
})
if err := g.Wait(); err != nil {
logger.Error(fmt.Sprintf("WS adapter service terminated: %s", err))
}
}
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn authn.Authentication, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
svc := ws.New(clientsClient, channels, authn, nps)
svc = middleware.NewTracing(tracer, svc)
svc = middleware.NewLogging(svc, logger)
counter, latency := prometheus.MakeMetrics("ws_adapter", "api")
svc = middleware.NewMetrics(svc, counter, latency)
return svc
}
func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger *slog.Logger, handler session.Handler) error {
config := mgate.Config{
Host: hostConfig.Host,
Port: hostConfig.Port,
TargetProtocol: targetWSProtocol,
TargetHost: targetWSHost,
TargetPort: targetWSPort,
}
wp, err := http.NewProxy(config, handler, logger, []string{}, []string{"/health", "/metrics"})
if err != nil {
return err
}
errCh := make(chan error)
go func() {
errCh <- wp.Listen(ctx)
}()
select {
case <-ctx.Done():
logger.Info(fmt.Sprintf("ws-adapter service shutdown at %s:%s", hostConfig.Host, hostConfig.Port))
return nil
case err := <-errCh:
return err
}
}
-11
View File
@@ -437,17 +437,6 @@ SMQ_COAP_ADAPTER_CACHE_MAX_COST=1048576
SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS=64
SMQ_COAP_ADAPTER_INSTANCE_ID=
### WS
SMQ_WS_ADAPTER_LOG_LEVEL=debug
SMQ_WS_ADAPTER_HTTP_HOST=ws-adapter
SMQ_WS_ADAPTER_HTTP_PORT=8186
SMQ_WS_ADAPTER_HTTP_SERVER_CERT=
SMQ_WS_ADAPTER_HTTP_SERVER_KEY=
SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS=200000
SMQ_WS_ADAPTER_CACHE_MAX_COST=1048576
SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS=64
SMQ_WS_ADAPTER_INSTANCE_ID=
## Addons Services
# Certs
AM_CERTS_LOG_LEVEL=debug
-114
View File
@@ -432,7 +432,6 @@ services:
- users
- mqtt-adapter
- http-adapter
- ws-adapter
- coap-adapter
ulimits:
nofile:
@@ -1538,119 +1537,6 @@ services:
bind:
create_host_path: true
ws-adapter:
image: docker.io/supermq/ws:${SMQ_RELEASE_TAG}
container_name: supermq-ws
depends_on:
- clients
- nats
restart: on-failure
environment:
SMQ_WS_ADAPTER_LOG_LEVEL: ${SMQ_WS_ADAPTER_LOG_LEVEL}
SMQ_WS_ADAPTER_HTTP_HOST: ${SMQ_WS_ADAPTER_HTTP_HOST}
SMQ_WS_ADAPTER_HTTP_PORT: ${SMQ_WS_ADAPTER_HTTP_PORT}
SMQ_WS_ADAPTER_HTTP_SERVER_CERT: ${SMQ_WS_ADAPTER_HTTP_SERVER_CERT}
SMQ_WS_ADAPTER_HTTP_SERVER_KEY: ${SMQ_WS_ADAPTER_HTTP_SERVER_KEY}
SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS: ${SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS}
SMQ_WS_ADAPTER_CACHE_MAX_COST: ${SMQ_WS_ADAPTER_CACHE_MAX_COST}
SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS: ${SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS}
SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL}
SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT}
SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt}
SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key}
SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt}
SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL}
SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT}
SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt}
SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key}
SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt}
SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL}
SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT}
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key}
SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt}
SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL}
SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT}
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM}
SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL}
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY}
SMQ_WS_ADAPTER_INSTANCE_ID: ${SMQ_WS_ADAPTER_INSTANCE_ID}
SMQ_ES_URL: ${SMQ_ES_URL}
ports:
- ${SMQ_WS_ADAPTER_HTTP_PORT}:${SMQ_WS_ADAPTER_HTTP_PORT}
networks:
- supermq-base-net
volumes:
# Clients gRPC mTLS client certificates
- type: bind
source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
target: /clients-grpc-client${SMQ_CLIENTS_GRPC_CLIENT_CERT:+.crt}
bind:
create_host_path: true
- type: bind
source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key}
target: /clients-grpc-client${SMQ_CLIENTS_GRPC_CLIENT_KEY:+.key}
bind:
create_host_path: true
- type: bind
source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca}
target: /clients-grpc-server-ca${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+.crt}
bind:
create_host_path: true
# Channels gRPC mTLS client certificates
- type: bind
source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
target: /channels-grpc-client${SMQ_CHANNELS_GRPC_CLIENT_CERT:+.crt}
bind:
create_host_path: true
- type: bind
source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key}
target: /channels-grpc-client${SMQ_CHANNELS_GRPC_CLIENT_KEY:+.key}
bind:
create_host_path: true
- type: bind
source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca}
target: /channels-grpc-server-ca${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+.crt}
bind:
create_host_path: true
# Auth gRPC mTLS client certificates
- type: bind
source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt}
bind:
create_host_path: true
- type: bind
source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key}
target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key}
bind:
create_host_path: true
- type: bind
source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca}
target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt}
bind:
create_host_path: true
# Domains gRPC mTLS client certificates
- type: bind
source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_CERT:+.crt}
bind:
create_host_path: true
- type: bind
source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key}
target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_KEY:+.key}
bind:
create_host_path: true
- type: bind
source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca}
target: /domains-grpc-server-ca${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+.crt}
bind:
create_host_path: true
rabbitmq:
image: docker.io/rabbitmq:4.1.4-management-alpine
container_name: supermq-rabbitmq
+1 -2
View File
@@ -22,7 +22,6 @@ envsubst '
${SMQ_CHANNELS_HTTP_PORT}
${SMQ_HTTP_ADAPTER_PORT}
${SMQ_NGINX_MQTT_PORT}
${SMQ_NGINX_MQTTS_PORT}
${SMQ_WS_ADAPTER_HTTP_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf
${SMQ_NGINX_MQTTS_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf
exec nginx -g "daemon off;"
+1 -1
View File
@@ -131,7 +131,7 @@ http {
location /ws/ {
include snippets/proxy-headers.conf;
include snippets/ws-upgrade.conf;
proxy_pass http://ws-adapter:${SMQ_WS_ADAPTER_HTTP_PORT}/;
proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/;
}
}
}
+1 -1
View File
@@ -144,7 +144,7 @@ http {
include snippets/verify-ssl-client.conf;
include snippets/proxy-headers.conf;
include snippets/ws-upgrade.conf;
proxy_pass http://ws-adapter:${SMQ_WS_ADAPTER_HTTP_PORT}/;
proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/;
}
}
}
+39 -39
View File
@@ -1,6 +1,6 @@
# HTTP Adapter
The HTTP Adapter exposes HTTP endpoints for publishing messages into SuperMQ channels. It authenticates clients via tokens or Basic auth, resolves domains/channels over gRPC, and forwards payloads to the message broker.
The HTTP Adapter exposes HTTP endpoints for publishing messages and WebSocket capabilities for publishing and subscribing to messages from SuperMQ channels. It authenticates clients via tokens or Basic auth, resolves domains/channels over gRPC, and forwards payloads to the message broker.
For more on SuperMQ, see the [official documentation][doc].
@@ -8,44 +8,44 @@ For more on SuperMQ, see the [official documentation][doc].
Environment variables (unset values fall back to defaults):
| Variable | Description | Default |
| -------------------------------------- | ------------------------------------------------------------------------------------------ | ------------------------------------- |
| `SMQ_HTTP_ADAPTER_LOG_LEVEL` | Log level (debug, info, warn, error) | debug |
| `SMQ_HTTP_ADAPTER_HOST` | HTTP Adapter host | http-adapter |
| `SMQ_HTTP_ADAPTER_PORT` | HTTP Adapter port | 8008 |
| `SMQ_HTTP_ADAPTER_SERVER_CERT` | Path to PEM-encoded server certificate (enables TLS) | "" |
| `SMQ_HTTP_ADAPTER_SERVER_KEY` | Path to PEM-encoded server key | "" |
| `SMQ_HTTP_ADAPTER_SERVER_CA_CERTS` | Trusted CA bundle for HTTPS server | "" |
| `SMQ_HTTP_ADAPTER_CLIENT_CA_CERTS` | Client CA bundle to require mTLS on HTTPS server | "" |
| `SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS` | Cache counters for topic parsing | 200000 |
| `SMQ_HTTP_ADAPTER_CACHE_MAX_COST` | Maximum cache size (bytes) | 1048576 |
| `SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS` | Cache buffer items | 64 |
| `SMQ_MESSAGE_BROKER_URL` | Message broker URL (publishing target) | nats://nats:4222 |
| `SMQ_ES_URL` | Event store URL (publishing middleware) | nats://nats:4222 |
| `SMQ_JAEGER_URL` | Jaeger tracing endpoint | <http://jaeger:4318/v1/traces> |
| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 |
| `SMQ_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true |
| `SMQ_HTTP_ADAPTER_INSTANCE_ID` | Service instance ID (auto-generated when empty) | "" |
| `SMQ_CLIENTS_GRPC_URL` | Clients service gRPC URL | clients:7006 |
| `SMQ_CLIENTS_GRPC_TIMEOUT` | Clients gRPC request timeout | 300s |
| `SMQ_CLIENTS_GRPC_CLIENT_CERT` | Clients gRPC client certificate | "" |
| `SMQ_CLIENTS_GRPC_CLIENT_KEY` | Clients gRPC client key | "" |
| `SMQ_CLIENTS_GRPC_SERVER_CA_CERTS` | Clients gRPC trusted CA bundle | "" |
| `SMQ_CHANNELS_GRPC_URL` | Channels service gRPC URL | channels:7005 |
| `SMQ_CHANNELS_GRPC_TIMEOUT` | Channels gRPC request timeout | 300s |
| `SMQ_CHANNELS_GRPC_CLIENT_CERT` | Channels gRPC client certificate | "" |
| `SMQ_CHANNELS_GRPC_CLIENT_KEY` | Channels gRPC client key | "" |
| `SMQ_CHANNELS_GRPC_SERVER_CA_CERTS` | Channels gRPC trusted CA bundle | "" |
| `SMQ_DOMAINS_GRPC_URL` | Domains service gRPC URL | domains:7003 |
| `SMQ_DOMAINS_GRPC_TIMEOUT` | Domains gRPC request timeout | 300s |
| `SMQ_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client certificate | "" |
| `SMQ_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key | "" |
| `SMQ_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC trusted CA bundle | "" |
| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | auth:7001 |
| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 300s |
| `SMQ_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client certificate | "" |
| `SMQ_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key | "" |
| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC trusted CA bundle | "" |
| Variable | Description | Default |
| ------------------------------------- | ---------------------------------------------------- | ------------------------------ |
| `SMQ_HTTP_ADAPTER_LOG_LEVEL` | Log level (debug, info, warn, error) | debug |
| `SMQ_HTTP_ADAPTER_HOST` | HTTP Adapter host | http-adapter |
| `SMQ_HTTP_ADAPTER_PORT` | HTTP Adapter port | 8008 |
| `SMQ_HTTP_ADAPTER_SERVER_CERT` | Path to PEM-encoded server certificate (enables TLS) | "" |
| `SMQ_HTTP_ADAPTER_SERVER_KEY` | Path to PEM-encoded server key | "" |
| `SMQ_HTTP_ADAPTER_SERVER_CA_CERTS` | Trusted CA bundle for HTTPS server | "" |
| `SMQ_HTTP_ADAPTER_CLIENT_CA_CERTS` | Client CA bundle to require mTLS on HTTPS server | "" |
| `SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS` | Cache counters for topic parsing | 200000 |
| `SMQ_HTTP_ADAPTER_CACHE_MAX_COST` | Maximum cache size (bytes) | 1048576 |
| `SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS` | Cache buffer items | 64 |
| `SMQ_MESSAGE_BROKER_URL` | Message broker URL (publishing target) | nats://nats:4222 |
| `SMQ_ES_URL` | Event store URL (publishing middleware) | nats://nats:4222 |
| `SMQ_JAEGER_URL` | Jaeger tracing endpoint | <http://jaeger:4318/v1/traces> |
| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 |
| `SMQ_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true |
| `SMQ_HTTP_ADAPTER_INSTANCE_ID` | Service instance ID (auto-generated when empty) | "" |
| `SMQ_CLIENTS_GRPC_URL` | Clients service gRPC URL | clients:7006 |
| `SMQ_CLIENTS_GRPC_TIMEOUT` | Clients gRPC request timeout | 300s |
| `SMQ_CLIENTS_GRPC_CLIENT_CERT` | Clients gRPC client certificate | "" |
| `SMQ_CLIENTS_GRPC_CLIENT_KEY` | Clients gRPC client key | "" |
| `SMQ_CLIENTS_GRPC_SERVER_CA_CERTS` | Clients gRPC trusted CA bundle | "" |
| `SMQ_CHANNELS_GRPC_URL` | Channels service gRPC URL | channels:7005 |
| `SMQ_CHANNELS_GRPC_TIMEOUT` | Channels gRPC request timeout | 300s |
| `SMQ_CHANNELS_GRPC_CLIENT_CERT` | Channels gRPC client certificate | "" |
| `SMQ_CHANNELS_GRPC_CLIENT_KEY` | Channels gRPC client key | "" |
| `SMQ_CHANNELS_GRPC_SERVER_CA_CERTS` | Channels gRPC trusted CA bundle | "" |
| `SMQ_DOMAINS_GRPC_URL` | Domains service gRPC URL | domains:7003 |
| `SMQ_DOMAINS_GRPC_TIMEOUT` | Domains gRPC request timeout | 300s |
| `SMQ_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client certificate | "" |
| `SMQ_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key | "" |
| `SMQ_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC trusted CA bundle | "" |
| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | auth:7001 |
| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 300s |
| `SMQ_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client certificate | "" |
| `SMQ_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key | "" |
| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC trusted CA bundle | "" |
## Deployment
+3 -3
View File
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws
package http
import (
"context"
@@ -47,8 +47,8 @@ type adapterService struct {
pubsub messaging.PubSub
}
// New instantiates the WS adapter implementation.
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub messaging.PubSub) Service {
// NewService instantiates the HTTP adapter implementation.
func NewService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub messaging.PubSub) Service {
return &adapterService{
clients: clients,
channels: channels,
+7 -7
View File
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws_test
package http_test
import (
"context"
@@ -16,6 +16,7 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/internal/testsutil"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
@@ -25,7 +26,6 @@ import (
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/ws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
@@ -59,19 +59,19 @@ var (
invalidEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", invalidID, invalidKey)))
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
func newService() (smqhttp.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
pubsub := new(mocks.PubSub)
clients := new(climocks.ClientsServiceClient)
channels := new(chmocks.ChannelsServiceClient)
authn := new(authnmocks.Authentication)
return ws.New(clients, channels, authn, pubsub), pubsub, clients, channels, authn
return smqhttp.NewService(clients, channels, authn, pubsub), pubsub, clients, channels, authn
}
func TestSubscribe(t *testing.T) {
svc, pubsub, clients, channels, auth := newService()
c := ws.NewClient(slog.Default(), nil, sessionID)
c := smqhttp.NewClient(slog.Default(), nil, sessionID)
cases := []struct {
desc string
@@ -149,11 +149,11 @@ func TestSubscribe(t *testing.T) {
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
subErr: ws.ErrFailedSubscription,
subErr: smqhttp.ErrFailedSubscription,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: ws.ErrFailedSubscription,
err: smqhttp.ErrFailedSubscription,
},
{
desc: "subscribe to channel with invalid clientKey",
+86
View File
@@ -5,12 +5,49 @@ package api
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"strings"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/go-kit/kit/endpoint"
)
func messageHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if isWebSocketRequest(r) {
handleWebSocket(ctx, svc, resolver, logger, w, r)
return
}
if r.Method != http.MethodPost {
encodeError(ctx, w, errMethodNotAllowed)
return
}
req, err := decodePublishReq(ctx, r)
if err != nil {
encodeError(ctx, w, err)
return
}
_, err = sendMessageEndpoint()(ctx, req)
if err != nil {
encodeError(ctx, w, err)
return
}
err = api.EncodeResponse(ctx, w, publishMessageRes{})
if err != nil {
encodeError(ctx, w, err)
}
}
}
func sendMessageEndpoint() endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(publishReq)
@@ -32,3 +69,52 @@ func healthCheckEndpoint() endpoint.Endpoint {
return healthCheckRes{}, nil
}
}
func handleWebSocket(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, w http.ResponseWriter, r *http.Request) {
req, err := decodeWSReq(r, resolver, logger)
if err != nil {
encodeError(ctx, w, err)
return
}
sessionID, err := generateSessionID()
if err != nil {
logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error()))
http.Error(w, "", http.StatusInternalServerError)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
return
}
client := smqhttp.NewClient(logger, conn, sessionID)
client.SetCloseHandler(func(code int, text string) error {
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic, messaging.MessageType)
})
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.username, req.password, req.domainID, req.channelID, req.subtopic, messaging.MessageType, client); err != nil {
conn.Close()
return
}
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID))
}
func isWebSocketRequest(r *http.Request) bool {
return strings.EqualFold(r.Header.Get(connHeaderKey), connHeaderVal) &&
strings.EqualFold(r.Header.Get(upgradeHeaderKey), upgradeHeaderVal)
}
func generateSessionID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", errors.Wrap(errGenSessionID, err)
}
return hex.EncodeToString(b), nil
}
+176 -115
View File
@@ -4,6 +4,7 @@
package api_test
import (
"context"
"fmt"
"io"
"net"
@@ -35,13 +36,26 @@ import (
"github.com/absmach/supermq/pkg/messaging"
pubsub "github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
invalidValue = "invalid"
clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
wsProtocol = "ws"
invalidKey = "invalid-key"
validToken = "valid-token"
invalidToken = "invalid-token"
ctSenmlJSON = "application/senml+json"
ctSenmlCBOR = "application/senml+cbor"
ctJSON = "application/json"
msgJSON = `{"field1":"val1","field2":"val2"}`
msgCBOR = `81A3616E6763757272656E746174206176FB3FF999999999999A`
msg = `[{"n":"current","t":-1,"v":1.6}]`
)
var (
@@ -51,18 +65,22 @@ var (
userID = testsutil.GenerateUUID(&testing.T{})
)
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) {
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub *pubsub.PubSub) server.Service {
return server.NewService(clients, channels, authn, pubsub)
}
func newHandler(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) {
pub := new(pubsub.PubSub)
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
if err != nil {
return nil, nil, err
}
return server.NewHandler(pub, authn, clients, channels, parser, smqlog.NewMock()), pub, nil
return server.NewHandler(pub, smqlog.NewMock(), authn, clients, channels, parser), pub, nil
}
func newTargetHTTPServer() *httptest.Server {
mux := api.MakeHandler(smqlog.NewMock(), instanceID)
func newTargetHTTPServer(resolver messaging.TopicResolver, svc server.Service) *httptest.Server {
mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID)
return httptest.NewServer(mux)
}
@@ -123,23 +141,14 @@ func TestPublish(t *testing.T) {
authn := new(authnMocks.Authentication)
channels := new(chmocks.ChannelsServiceClient)
domains := new(dmocks.DomainsServiceClient)
ctSenmlJSON := "application/senml+json"
ctSenmlCBOR := "application/senml+cbor"
ctJSON := "application/json"
clientKey := "client_key"
invalidKey := invalidValue
validToken := "token"
invalidToken := "invalid_token"
msg := `[{"n":"current","t":-1,"v":1.6}]`
msgJSON := `{"field1":"val1","field2":"val2"}`
msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A`
svc, pub, err := newService(authn, clients, channels, domains)
assert.Nil(t, err, fmt.Sprintf("failed to create service with err: %v", err))
target := newTargetHTTPServer()
resolver := messaging.NewTopicResolver(channels, domains)
handler, pubsub, err := newHandler(authn, clients, channels, domains)
assert.Nil(t, err, fmt.Sprintf("failed to create handler with err: %v", err))
svc := newService(clients, channels, authn, pubsub)
target := newTargetHTTPServer(resolver, svc)
defer target.Close()
ts, err := newProxyHTPPServer(svc, target)
assert.Nil(t, err, fmt.Sprintf("failed to create proxy server with err: %v", err))
ts, err := newProxyHTPPServer(handler, target)
require.Nil(t, err)
defer ts.Close()
cases := []struct {
@@ -323,7 +332,7 @@ func TestPublish(t *testing.T) {
ClientType: tc.clientType,
Type: uint32(connections.Publish),
}).Return(tc.authzRes, tc.authzErr)
svcCall := pub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil)
svcCall := pubsub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil)
req := testRequest{
client: ts.Client(),
method: http.MethodPost,
@@ -346,123 +355,175 @@ func TestPublish(t *testing.T) {
}
}
func TestHealthCheck(t *testing.T) {
func TestHandshake(t *testing.T) {
clients := new(climocks.ClientsServiceClient)
authn := new(authnMocks.Authentication)
channels := new(chmocks.ChannelsServiceClient)
authn := new(authnMocks.Authentication)
domains := new(dmocks.DomainsServiceClient)
clientKey := "client_key"
invalidKey := invalidValue
validToken := "token"
invalidToken := "invalid_token"
svc, _, err := newService(authn, clients, channels, domains)
assert.Nil(t, err, fmt.Sprintf("failed to create service with err: %v", err))
target := newTargetHTTPServer()
resolver := messaging.NewTopicResolver(channels, domains)
handler, pubsub, err := newHandler(authn, clients, channels, domains)
assert.Nil(t, err, fmt.Sprintf("failed to create handler with err: %v", err))
svc := newService(clients, channels, authn, pubsub)
target := newTargetHTTPServer(resolver, svc)
defer target.Close()
ts, err := newProxyHTPPServer(svc, target)
assert.Nil(t, err, fmt.Sprintf("failed to create proxy server with err: %v", err))
ts, err := newProxyHTPPServer(handler, target)
require.Nil(t, err)
defer ts.Close()
msg := []byte(`[{"n":"current","t":-1,"v":1.6}]`)
pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil)
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
cases := []struct {
desc string
domainID string
clientID string
clientType string
key string
status int
basicAuth bool
bearerToken bool
authnErr error
authnRes *grpcClientsV1.AuthnRes
authnRes1 smqauthn.Session
desc string
domainID string
chanID string
subtopic string
header bool
clientKey string
status int
err error
msg []byte
}{
{
desc: "health check successfully",
domainID: domainID,
key: clientKey,
status: http.StatusOK,
authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
},
{
desc: "health check with basic auth",
desc: "connect and send message",
domainID: domainID,
key: clientKey,
basicAuth: true,
status: http.StatusOK,
authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
chanID: chanID,
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "health check with invalid key",
domainID: domainID,
key: invalidKey,
status: http.StatusUnauthorized,
authnRes: &grpcClientsV1.AuthnRes{Authenticated: false},
},
{
desc: "health check with invalid basic auth",
desc: "connect and send message with clientKey as query parameter",
domainID: domainID,
key: invalidKey,
basicAuth: true,
chanID: chanID,
subtopic: "",
header: false,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message that cannot be published",
domainID: domainID,
chanID: chanID,
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: []byte{},
},
{
desc: "connect and send message to subtopic",
domainID: domainID,
chanID: chanID,
subtopic: "subtopic",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to nested subtopic",
domainID: domainID,
chanID: chanID,
subtopic: "subtopic/nested",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to all subtopics",
domainID: domainID,
chanID: chanID,
subtopic: ">",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect to empty channel",
domainID: domainID,
chanID: "",
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusUnauthorized,
authnRes: &grpcClientsV1.AuthnRes{Authenticated: false},
msg: []byte{},
},
{
desc: "health check with valid bearer token",
domainID: domainID,
key: validToken,
bearerToken: true,
status: http.StatusOK,
authnRes1: smqauthn.Session{UserID: userID},
desc: "connect with empty clientKey",
domainID: domainID,
chanID: chanID,
subtopic: "",
header: true,
clientKey: "",
status: http.StatusBadRequest,
msg: []byte{},
},
{
desc: "health check with invalid bearer token",
domainID: domainID,
key: invalidToken,
bearerToken: true,
status: http.StatusUnauthorized,
authnRes1: smqauthn.Session{},
authnErr: svcerr.ErrAuthentication,
},
{
desc: "health check with empty key",
domainID: domainID,
key: "",
status: http.StatusBadRequest,
},
{
desc: "health check with empty domain ID",
domainID: "",
key: clientKey,
status: http.StatusBadRequest,
},
{
desc: "health check with invalid domain ID",
domainID: invalidValue,
key: clientKey,
status: http.StatusUnauthorized,
authnRes: &grpcClientsV1.AuthnRes{},
desc: "connect and send message to subtopic with invalid name",
domainID: domainID,
chanID: chanID,
subtopic: "sub/a*b/topic",
header: true,
clientKey: clientKey,
status: http.StatusUnauthorized,
msg: msg,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr)
authCall := authn.On("Authenticate", mock.Anything, tc.key).Return(tc.authnRes1, tc.authnErr)
domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
req := testRequest{
client: ts.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/hc/%s", ts.URL, tc.domainID),
token: tc.key,
basicAuth: tc.basicAuth,
bearerToken: tc.bearerToken,
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header)
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
if tc.status == http.StatusSwitchingProtocols {
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err))
err = conn.WriteMessage(websocket.TextMessage, tc.msg)
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err))
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
clientsCall.Unset()
authCall.Unset()
domainsCall.Unset()
})
}
}
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
u, _ := url.Parse(tsURL)
u.Scheme = wsProtocol
if chanID == "0" || chanID == "" {
if header {
return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id")
}
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id")
}
subtopicPart := ""
if subtopic != "" {
subtopicPart = fmt.Sprintf("/%s", subtopic)
}
if header {
return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil
}
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil
}
func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
header := http.Header{}
if addHeader {
header.Add("Authorization", clientKey)
}
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader)
conn, res, errRet := websocket.DefaultDialer.Dial(turl, header)
return conn, res, errRet
}
+8
View File
@@ -24,6 +24,14 @@ func (req publishReq) validate() error {
return nil
}
type connReq struct {
username string
password string
channelID string
domainID string
subtopic string
}
type healthCheckReq struct {
domain string
token string
+93 -19
View File
@@ -5,6 +5,7 @@ package api
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
@@ -12,41 +13,52 @@ import (
"github.com/absmach/supermq"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
const (
ctSenmlJSON = "application/senml+json"
ctSenmlCBOR = "application/senml+cbor"
contentType = "application/json"
authzHeaderKey = "Authorization"
ctSenmlJSON = "application/senml+json"
ctSenmlCBOR = "application/senml+cbor"
contentType = "application/json"
authzHeaderKey = "Authorization"
authzQueryKey = "authorization"
connHeaderKey = "Connection"
connHeaderVal = "upgrade"
upgradeHeaderKey = "Upgrade"
upgradeHeaderVal = "websocket"
readwriteBufferSize = 1024
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: readwriteBufferSize,
WriteBufferSize: readwriteBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
}
errUnauthorizedAccess = errors.New("missing or invalid credentials provided")
errMalformedSubtopic = errors.New("malformed subtopic")
errGenSessionID = errors.New("failed to generate session id")
errMethodNotAllowed = errors.New("method not allowed")
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(logger *slog.Logger, instanceID string) http.Handler {
func MakeHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
r := chi.NewRouter()
r.Post("/m/{domain}/c/{channel}", otelhttp.NewHandler(kithttp.NewServer(
sendMessageEndpoint(),
decodeRequest,
api.EncodeResponse,
opts...,
), "publish").ServeHTTP)
r.Post("/m/{domain}/c/{channel}/*", otelhttp.NewHandler(kithttp.NewServer(
sendMessageEndpoint(),
decodeRequest,
api.EncodeResponse,
opts...,
), "publish").ServeHTTP)
r.Handle("/m/{domain}/c/{channel}", messageHandler(ctx, svc, resolver, logger))
r.Handle("/m/{domain}/c/{channel}/*", messageHandler(ctx, svc, resolver, logger))
r.Post("/hc/{domain}", otelhttp.NewHandler(kithttp.NewServer(
healthCheckEndpoint(),
@@ -61,7 +73,7 @@ func MakeHandler(logger *slog.Logger, instanceID string) http.Handler {
return r
}
func decodeRequest(_ context.Context, r *http.Request) (any, error) {
func decodePublishReq(_ context.Context, r *http.Request) (any, error) {
ct := r.Header.Get("Content-Type")
if ct != ctSenmlJSON && ct != contentType && ct != ctSenmlCBOR {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
@@ -87,6 +99,48 @@ func decodeRequest(_ context.Context, r *http.Request) (any, error) {
return req, nil
}
func decodeWSReq(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
username, password, ok := r.BasicAuth()
if !ok {
switch {
case r.URL.Query().Get(authzQueryKey) != "":
password = r.URL.Query().Get(authzQueryKey)
case r.Header.Get(authzHeaderKey) != "":
password = r.Header.Get(authzHeaderKey)
default:
logger.Debug("Missing authorization key.")
return connReq{}, errUnauthorizedAccess
}
}
domain := chi.URLParam(r, "domain")
channel := chi.URLParam(r, "channel")
domainID, channelID, _, err := resolver.Resolve(r.Context(), domain, channel)
if err != nil {
return connReq{}, err
}
req := connReq{
username: username,
password: password,
channelID: channelID,
domainID: domainID,
}
subTopic := chi.URLParam(r, "*")
if subTopic != "" {
subTopic, err := messaging.ParseSubscribeSubtopic(subTopic)
if err != nil {
return connReq{}, err
}
req.subtopic = subTopic
}
return req, nil
}
func decodeHealthCheckRequest(_ context.Context, r *http.Request) (any, error) {
var req healthCheckReq
req.domain = chi.URLParam(r, "domain")
@@ -104,3 +158,23 @@ func decodeHealthCheckRequest(_ context.Context, r *http.Request) (any, error) {
return req, nil
}
func encodeError(ctx context.Context, w http.ResponseWriter, err error) {
switch err {
case smqhttp.ErrEmptyTopic:
w.WriteHeader(http.StatusBadRequest)
case errUnauthorizedAccess:
w.WriteHeader(http.StatusForbidden)
case errMalformedSubtopic, errors.ErrMalformedEntity:
w.WriteHeader(http.StatusBadRequest)
default:
api.EncodeError(ctx, err, w)
return
}
if errorVal, ok := err.(errors.Error); ok {
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}
}
+1 -1
View File
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws
package http
import (
"context"
+4 -4
View File
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws_test
package http_test
import (
"context"
@@ -14,7 +14,7 @@ import (
"testing"
"time"
"github.com/absmach/supermq/ws"
smqhttp "github.com/absmach/supermq/http"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
@@ -23,7 +23,7 @@ const expectedCount = uint64(2)
var (
msgChan = make(chan []byte)
c *ws.Client
c *smqhttp.Client
count uint64
upgrader = websocket.Upgrader{
@@ -63,7 +63,7 @@ func TestHandle(t *testing.T) {
}
defer wsConn.Close()
c = ws.NewClient(slog.Default(), wsConn, "sessionID")
c = smqhttp.NewClient(slog.Default(), wsConn, "sessionID")
go c.Start(context.Background())
cases := []struct {
+149 -90
View File
@@ -31,65 +31,116 @@ const protocol = "http"
// Log message formats.
const (
publishedInfoFmt = "published with client_type %s client_id %s to the topic %s"
failedAuthnFmt = "failed to authenticate client_type %s for topic %s with error %s"
LogInfoSubscribed = "subscribed with client_id %s to topics %s"
LogInfoConnected = "connected with client_id %s"
LogInfoDisconnected = "disconnected client_id %s and username %s"
LogInfoPublished = "published with client_id %s to the topic %s"
)
// Error wrappers for MQTT errors.
var (
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
errFailedPublish = errors.New("failed to publish")
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
errInvalidAuthFormat = errors.New("invalid basic auth format")
errInvalidClientType = errors.New("invalid client type")
)
// Event implements events.Event interface.
type handler struct {
publisher messaging.Publisher
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
parser messaging.TopicParser
authn smqauthn.Authentication
logger *slog.Logger
pubsub messaging.PubSub
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
authn smqauthn.Authentication
logger *slog.Logger
parser messaging.TopicParser
}
// NewHandler creates new Handler entity.
func NewHandler(publisher messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser, logger *slog.Logger) session.Handler {
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler {
return &handler{
publisher: publisher,
authn: authn,
clients: clients,
channels: channels,
parser: parser,
logger: logger,
logger: logger,
pubsub: pubsub,
authn: authn,
clients: clients,
channels: channels,
parser: parser,
}
}
// AuthConnect is called on device connection,
// prior forwarding to the HTTP server.
// prior forwarding to the http server.
func (h *handler) AuthConnect(ctx context.Context) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
if string(s.Password) == "" {
var tok string
switch {
case string(s.Password) == "":
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey))
case strings.HasPrefix(string(s.Password), apiutil.ClientPrefix):
tok = strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
default:
tok = string(s.Password)
}
h.logger.Info(fmt.Sprintf(LogInfoConnected, tok))
return nil
}
// AuthPublish is called on device publish,
// prior forwarding to the http server.
func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
if topic == nil {
return errMissingTopicPub
}
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
domainID, channelID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType)
if err != nil {
fmt.Println("AuthPublish authAccess error:", err)
return err
}
if s.Username == "" {
s.Username = clientID
}
return nil
}
// AuthPublish is not used in HTTP service.
func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
return nil
}
// AuthSubscribe is not used in HTTP service.
// AuthSubscribe is called on device publish,
// prior forwarding to the MQTT broker.
func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
if topics == nil || *topics == nil {
return errMissingTopicSub
}
for _, topic := range *topics {
domainID, channelID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, true)
if err != nil {
return err
}
if _, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
return err
}
}
return nil
}
@@ -100,54 +151,19 @@ func (h *handler) Connect(ctx context.Context) error {
// Publish - after client successfully published.
func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error {
if topic == nil {
return errMissingTopicPub
}
topic = &strings.Split(*topic, "?")[0]
s, ok := session.FromContext(ctx)
if !ok {
return errors.Wrap(errFailedPublish, errClientNotInitialized)
return errClientNotInitialized
}
if len(*payload) == 0 {
h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username))
return nil
}
domainID, channelID, subtopic, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return errors.Wrap(errMalformedTopic, err)
}
var token, clientType string
pass := string(s.Password)
switch {
case s.Username != "" && pass != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, s.Username, pass)
clientType = policies.ClientType
case strings.HasPrefix(pass, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(pass, apiutil.BasicAuthPrefix))
if err != nil {
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, policies.ClientType, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
case strings.HasPrefix(pass, apiutil.ClientPrefix):
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(pass, apiutil.ClientPrefix))
clientType = policies.ClientType
case strings.HasPrefix(pass, apiutil.BearerPrefix):
token = strings.TrimPrefix(pass, apiutil.BearerPrefix)
clientType = policies.UserType
default:
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
id, err := h.authenticate(ctx, clientType, token)
if err != nil {
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, clientType, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
// Health topics are not published to message broker.
if topicType == messaging.HealthType {
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
return nil
return errors.Wrap(errFailedPublish, err)
}
msg := messaging.Message{
@@ -155,50 +171,93 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
Domain: domainID,
Channel: channelID,
Subtopic: subtopic,
Publisher: id,
Payload: *payload,
Publisher: s.Username,
Created: time.Now().UnixNano(),
}
ar := &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ClientId: id,
ClientType: clientType,
ChannelId: msg.Channel,
Type: uint32(connections.Publish),
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
if !res.GetAuthorized() {
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization)
// Health check topic messages do not get published to message broker.
if topicType == messaging.MessageType {
if err := h.pubsub.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err))
}
}
if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return errors.Wrap(errFailedPublishToMsgBroker, err)
}
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
return nil
}
// Subscribe - not used for HTTP.
// Subscribe - after client successfully subscribed.
func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ",")))
return nil
}
// Unsubscribe - not used for HTTP.
// Unsubscribe - after client unsubscribed.
func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
return nil
}
// Disconnect - not used for HTTP.
// Disconnect - connection with broker or client lost.
func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var token, clientType string
var err error
switch {
case strings.HasPrefix(password, apiutil.BearerPrefix):
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
clientType = policies.UserType
case username != "" && password != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
default:
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
clientType = policies.ClientType
}
id, err := h.authenticate(ctx, clientType, token)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
// Health check topics do not require channel authorization.
if topicType == messaging.HealthType {
return id, nil
}
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: id,
ClientType: clientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !res.GetAuthorized() {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
return id, nil
}
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
switch authType {
case policies.UserType:
@@ -210,7 +269,7 @@ func (h *handler) authenticate(ctx context.Context, authType, token string) (str
case policies.ClientType:
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
return "", err
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
+571 -297
View File
@@ -11,7 +11,7 @@ import (
"strings"
"testing"
mghttp "github.com/absmach/mgate/pkg/http"
mgate "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
@@ -19,49 +19,41 @@ import (
chmocks "github.com/absmach/supermq/channels/mocks"
clmocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
mhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/internal/testsutil"
smqhttp "github.com/absmach/supermq/http"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
const (
clientID = "513d02d2-16c1-4f23-98be-9e12f8fee898"
clientKey = "password"
chanID = "123e4567-e89b-12d3-a456-000000000001"
invalidValue = "invalidValue"
)
var (
domainID = testsutil.GenerateUUID(&testing.T{})
invalidValue = "invalid"
topicMsg = "/m/%s/c/%s"
subtopicMsg = "/m/%s/c/%s/subtopic"
hcTopicFmt = "/hc/%s"
hcTopic = fmt.Sprintf(hcTopicFmt, domainID)
topic = fmt.Sprintf(topicMsg, domainID, chanID)
subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID)
invalidTopic = invalidValue
hcTopicFmt = "/hc/%s"
hcTopic = fmt.Sprintf(hcTopicFmt, domainID)
invalidHCTopic = "/hc"
invalidTopic = invalidValue
topics = []string{topic}
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
sessionClient = session.Session{
ID: clientID,
Password: []byte(clientKey),
}
validToken = "token"
invalidToken = "invalid_token"
validID = testsutil.GenerateUUID(&testing.T{})
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMalformedTopic = errors.New("malformed topic")
errMalformedSubtopic = errors.New("malformed subtopic")
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
invalidChannelIDTopic = "m/**/c"
validToken = "token"
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
)
var (
@@ -81,73 +73,21 @@ func newHandler(t *testing.T) session.Handler {
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err))
return mhttp.NewHandler(publisher, authn, clients, channels, parser, logger)
return smqhttp.NewHandler(publisher, logger, authn, clients, channels, parser)
}
func TestAuthConnect(t *testing.T) {
func TestAuthPublish(t *testing.T) {
handler := newHandler(t)
cases := []struct {
desc string
session *session.Session
status int
err error
}{
{
desc: "connect with valid username and password",
err: nil,
session: &sessionClient,
},
{
desc: "connect without active session",
session: nil,
status: http.StatusUnauthorized,
err: errClientNotInitialized,
},
{
desc: "connect with empty key",
session: &session.Session{
ID: clientID,
Password: []byte(""),
},
status: http.StatusBadRequest,
err: errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey),
},
{
desc: "connect with client key",
session: &session.Session{
ID: clientID,
Password: []byte("Client " + clientKey),
},
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
err := handler.AuthConnect(ctx)
hpe, ok := err.(mghttp.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
}
})
}
}
func TestPublish(t *testing.T) {
handler := newHandler(t)
malformedSubtopics := topic + "/" + subtopic + "%"
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
@@ -166,299 +106,633 @@ func TestPublish(t *testing.T) {
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
cases := []struct {
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
topic *string
channelID string
username string
payload *[]byte
password string
session *session.Session
topic *string
payload *[]byte
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
publishErr error
err error
}{
{
desc: "publish with key successfully",
topic: &topic,
payload: &payload,
password: clientKey,
desc: "publish with client key successfully",
session: &clientKeySession,
channelID: chanID,
topic: &topic,
authKey: clientKey,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with token successfully",
topic: &topic,
payload: &payload,
password: validToken,
session: &tokenSession,
channelID: chanID,
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with invalid token",
topic: &topic,
payload: &payload,
password: invalidToken,
session: &invalidTokenSession,
channelID: chanID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with key and subtopic successfully",
topic: &subtopic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with empty topic",
topic: nil,
payload: &payload,
session: &clientKeySession,
channelID: chanID,
status: http.StatusBadRequest,
err: errMissingTopicPub,
},
{
desc: "publish with invalid session",
topic: &topic,
payload: &payload,
session: nil,
channelID: chanID,
status: http.StatusUnauthorized,
err: errClientNotInitialized,
},
{
desc: "publish with invalid topic",
topic: &invalidTopic,
status: http.StatusBadRequest,
password: clientKey,
session: &clientKeySession,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
err: errMalformedTopic,
},
{
desc: "publish with malformed subtopic",
topic: &malformedSubtopics,
status: http.StatusBadRequest,
password: clientKey,
session: &clientKeySession,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
err: errMalformedSubtopic,
},
{
desc: "publish with empty password",
topic: &topic,
payload: &payload,
session: &session.Session{
Password: []byte(""),
},
channelID: chanID,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with client key and failed to authenticate",
desc: "publish with invalid client key",
session: &invalidClientKeySession,
topic: &topic,
authKey: invalidKey,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: nil,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with client key and failed to authenticate with error",
desc: "publish with nil session",
session: nil,
topic: &topic,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "publish with empty topic",
session: &clientKeySession,
topic: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicPub,
},
{
desc: "publish with unauthorized client key",
session: &unauthorizedKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusUnauthorized,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with token and failed to authenticate",
topic: &topic,
payload: &payload,
password: validToken,
session: &tokenSession,
channelID: chanID,
status: http.StatusUnauthorized,
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with basic auth",
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &basicAuthSession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with token successfully",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with invalid token",
session: &invalidTokenSession,
topic: &topic,
authKey: invalidToken,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with unauthorized token",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with basic auth successfully",
session: &basicAuthSession,
topic: &topic,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid basic auth",
session: &invalidBasicAuthSession,
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &invalidBasicAuthSession,
channelID: chanID,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{},
authNErr: svcerr.ErrAuthentication,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with encoded credentials",
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &encodedCredsSession,
channelID: chanID,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with unauthorized client",
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: nil,
err: svcerr.ErrAuthorization,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with authorization error",
topic: &topic,
desc: "publish with health check topic successfully",
session: &hcClientKeySession,
topic: &hcTopic,
authKey: clientKey,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: "",
domainID: domainID,
clientID: userID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
status: http.StatusUnauthorized,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "publish with failed to publish",
topic: &topic,
desc: "publish with invalid health check topic",
session: &hcClientKeySession,
topic: &invalidHCTopic,
authKey: clientKey,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
publishErr: errFailedPublishToMsgBroker,
err: errFailedPublishToMsgBroker,
},
{
desc: "publish health check with token successfully",
topic: &hcTopic,
payload: &payload,
password: validToken,
session: &tokenSession,
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
err: nil,
},
{
desc: "publish health check with invalid topic",
topic: &invalidHCTopic,
status: http.StatusBadRequest,
password: validToken,
session: &tokenSession,
err: errMalformedTopic,
status: http.StatusBadRequest,
err: messaging.ErrMalformedTopic,
clientType: policies.ClientType,
},
}
for _, tc := range cases {
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
var internalTopic string
if tc.topic != nil {
internalTopic = strings.TrimPrefix(strings.ReplaceAll(*tc.topic, "/", "."), ".m.")
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr)
repoCall := publisher.On("Publish", ctx, internalTopic, mock.Anything).Return(tc.publishErr)
err := handler.Publish(ctx, tc.topic, tc.payload)
hpe, ok := err.(mghttp.HTTPProxyError)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Publish),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthPublish(ctx, tc.topic, tc.payload)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
repoCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestAuthSubscribe(t *testing.T) {
handler := newHandler(t)
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
session *session.Session
topics *[]string
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
err error
}{
{
desc: "subscribe with client key successfully",
session: &clientKeySession,
topics: &topics,
authKey: clientKey,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid client key",
session: &invalidClientKeySession,
topics: &topics,
authKey: invalidKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with empty topics",
session: &clientKeySession,
topics: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicSub,
},
{
desc: "subscribe with nil session",
session: nil,
topics: &topics,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "subscribe with unauthorized client key",
session: &unauthorizedKeySession,
topics: &topics,
authKey: clientKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with token successfully",
session: &tokenSession,
topics: &topics,
authKey: token,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid token",
session: &invalidTokenSession,
topics: &topics,
authKey: invalidToken,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with unauthorized token",
session: &tokenSession,
topics: &topics,
authKey: token,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with basic auth successfully",
session: &basicAuthSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid basic auth",
session: &invalidBasicAuthSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with health check topic successfully",
session: &hcClientKeySession,
topics: &[]string{hcTopic},
authKey: clientKey,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
err: nil,
},
{
desc: "subscribe with invalid health check topic",
session: &hcClientKeySession,
topics: &[]string{invalidHCTopic},
authKey: clientKey,
status: http.StatusBadRequest,
err: messaging.ErrMalformedTopic,
clientType: policies.ClientType,
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthSubscribe(ctx, tc.topics)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestPublish(t *testing.T) {
handler := newHandler(t)
malformedSubtopics := topic + "/" + subtopic + "%"
wrongCharSubtopics := topic + "/" + subtopic + ">"
validSubtopic := topic + "/" + subtopic
cases := []struct {
desc string
session *session.Session
topic string
payload []byte
err error
}{
{
desc: "publish without active session",
session: nil,
topic: topic,
payload: payload,
err: errClientNotInitialized,
},
{
desc: "publish with invalid topic",
session: &sessionClient,
topic: invalidTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with invalid channel ID",
session: &sessionClient,
topic: invalidChannelIDTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with malformed subtopic",
session: &sessionClient,
topic: malformedSubtopics,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with subtopic containing wrong character",
session: &sessionClient,
topic: wrongCharSubtopics,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with subtopic",
session: &sessionClient,
topic: validSubtopic,
payload: payload,
},
{
desc: "publish without subtopic",
session: &sessionClient,
topic: topic,
payload: payload,
},
{
desc: "publish with health check topic",
session: &sessionClient,
topic: hcTopic,
payload: payload,
},
{
desc: "puvlish with invalid health check topic",
session: &sessionClient,
topic: invalidHCTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
}
for _, tc := range cases {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
repoCall.Unset()
}
}
@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
// Package middleware provides logging, metrics and tracing middleware
// for SuperMQ WebSockets service.
// for SuperMQ HTTP service.
//
// For more details about tracing instrumentation for SuperMQ refer to the
// documentation at https://docs.supermq.absmach.eu/tracing/.
@@ -8,25 +8,25 @@ import (
"log/slog"
"time"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
)
var _ ws.Service = (*loggingMiddleware)(nil)
var _ smqhttp.Service = (*loggingMiddleware)(nil)
type loggingMiddleware struct {
logger *slog.Logger
svc ws.Service
svc smqhttp.Service
}
// NewLogging adds logging facilities to the websocket service.
func NewLogging(svc ws.Service, logger *slog.Logger) ws.Service {
func NewLogging(svc smqhttp.Service, logger *slog.Logger) smqhttp.Service {
return &loggingMiddleware{logger, svc}
}
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *smqhttp.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -9,21 +9,21 @@ import (
"context"
"time"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
"github.com/go-kit/kit/metrics"
)
var _ ws.Service = (*metricsMiddleware)(nil)
var _ smqhttp.Service = (*metricsMiddleware)(nil)
type metricsMiddleware struct {
counter metrics.Counter
latency metrics.Histogram
svc ws.Service
svc smqhttp.Service
}
// NewMetrics instruments adapter by tracking request count and latency.
func NewMetrics(svc ws.Service, counter metrics.Counter, latency metrics.Histogram) ws.Service {
func NewMetrics(svc smqhttp.Service, counter metrics.Counter, latency metrics.Histogram) smqhttp.Service {
return &metricsMiddleware{
counter: counter,
latency: latency,
@@ -32,7 +32,7 @@ func NewMetrics(svc ws.Service, counter metrics.Counter, latency metrics.Histogr
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *smqhttp.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
@@ -6,12 +6,12 @@ package middleware
import (
"context"
smqhttp "github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
"go.opentelemetry.io/otel/trace"
)
var _ ws.Service = (*tracingMiddleware)(nil)
var _ smqhttp.Service = (*tracingMiddleware)(nil)
const (
subscribeOP = "subscribe_op"
@@ -20,19 +20,19 @@ const (
type tracingMiddleware struct {
tracer trace.Tracer
svc ws.Service
svc smqhttp.Service
}
// NewTracing returns a new websocket service with tracing capabilities.
func NewTracing(tracer trace.Tracer, svc ws.Service) ws.Service {
func NewTracing(tracer trace.Tracer, svc smqhttp.Service) smqhttp.Service {
return &tracingMiddleware{
tracer: tracer,
svc: svc,
}
}
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *ws.Client) error {
// Subscribe traces the "Subscribe" operation of the wrapped smqhttp.Service.
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *smqhttp.Client) error {
ctx, span := tm.tracer.Start(ctx, subscribeOP)
defer span.End()
+224
View File
@@ -0,0 +1,224 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
"github.com/absmach/supermq/http"
"github.com/absmach/supermq/pkg/messaging"
mock "github.com/stretchr/testify/mock"
)
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
type Service_Expecter struct {
mock *mock.Mock
}
func (_m *Service) EXPECT() *Service_Expecter {
return &Service_Expecter{mock: &_m.Mock}
}
// Subscribe provides a mock function for the type Service
func (_mock *Service) Subscribe(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client) error {
ret := _mock.Called(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)
if len(ret) == 0 {
panic("no return value specified for Subscribe")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string, messaging.TopicType, *http.Client) error); ok {
r0 = returnFunc(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe'
type Service_Subscribe_Call struct {
*mock.Call
}
// Subscribe is a helper method to define mock.On call
// - ctx context.Context
// - sessionID string
// - username string
// - password string
// - domainID string
// - chanID string
// - subtopic string
// - topicType messaging.TopicType
// - client *http.Client
func (_e *Service_Expecter) Subscribe(ctx interface{}, sessionID interface{}, username interface{}, password interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, topicType interface{}, client interface{}) *Service_Subscribe_Call {
return &Service_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)}
}
func (_c *Service_Subscribe_Call) Run(run func(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client)) *Service_Subscribe_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
var arg4 string
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 string
if args[5] != nil {
arg5 = args[5].(string)
}
var arg6 string
if args[6] != nil {
arg6 = args[6].(string)
}
var arg7 messaging.TopicType
if args[7] != nil {
arg7 = args[7].(messaging.TopicType)
}
var arg8 *http.Client
if args[8] != nil {
arg8 = args[8].(*http.Client)
}
run(
arg0,
arg1,
arg2,
arg3,
arg4,
arg5,
arg6,
arg7,
arg8,
)
})
return _c
}
func (_c *Service_Subscribe_Call) Return(err error) *Service_Subscribe_Call {
_c.Call.Return(err)
return _c
}
func (_c *Service_Subscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client) error) *Service_Subscribe_Call {
_c.Call.Return(run)
return _c
}
// Unsubscribe provides a mock function for the type Service
func (_mock *Service) Unsubscribe(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType) error {
ret := _mock.Called(ctx, sessionID, domainID, chanID, subtopic, topicType)
if len(ret) == 0 {
panic("no return value specified for Unsubscribe")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, messaging.TopicType) error); ok {
r0 = returnFunc(ctx, sessionID, domainID, chanID, subtopic, topicType)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_Unsubscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unsubscribe'
type Service_Unsubscribe_Call struct {
*mock.Call
}
// Unsubscribe is a helper method to define mock.On call
// - ctx context.Context
// - sessionID string
// - domainID string
// - chanID string
// - subtopic string
// - topicType messaging.TopicType
func (_e *Service_Expecter) Unsubscribe(ctx interface{}, sessionID interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, topicType interface{}) *Service_Unsubscribe_Call {
return &Service_Unsubscribe_Call{Call: _e.mock.On("Unsubscribe", ctx, sessionID, domainID, chanID, subtopic, topicType)}
}
func (_c *Service_Unsubscribe_Call) Run(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType)) *Service_Unsubscribe_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
var arg4 string
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 messaging.TopicType
if args[5] != nil {
arg5 = args[5].(messaging.TopicType)
}
run(
arg0,
arg1,
arg2,
arg3,
arg4,
arg5,
)
})
return _c
}
func (_c *Service_Unsubscribe_Call) Return(err error) *Service_Unsubscribe_Call {
_c.Call.Return(err)
return _c
}
func (_c *Service_Unsubscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType) error) *Service_Unsubscribe_Call {
_c.Call.Return(run)
return _c
}
+8 -3
View File
@@ -24,6 +24,7 @@ import (
dmocks "github.com/absmach/supermq/domains/mocks"
adapter "github.com/absmach/supermq/http"
"github.com/absmach/supermq/http/api"
httpmocks "github.com/absmach/supermq/http/mocks"
smqlog "github.com/absmach/supermq/logger"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/errors"
@@ -47,12 +48,14 @@ func setupMessages(t *testing.T) (*httptest.Server, *pubsub.PubSub) {
domainsGRPCClient = new(dmocks.DomainsServiceClient)
pub := new(pubsub.PubSub)
authn := new(authnmocks.Authentication)
svc := new(httpmocks.Service)
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channelsGRPCClient, domainsGRPCClient)
assert.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err))
handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, parser, smqlog.NewMock())
handler := adapter.NewHandler(pub, smqlog.NewMock(), authn, clientsGRPCClient, channelsGRPCClient, parser)
resolver := messaging.NewTopicResolver(channelsGRPCClient, domainsGRPCClient)
mux := api.MakeHandler(smqlog.NewMock(), "")
mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), "")
target := httptest.NewServer(mux)
ptUrl, _ := url.Parse(target.URL)
@@ -186,7 +189,9 @@ func TestSendMessage(t *testing.T) {
domainsCall := domainsGRPCClient.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
channelsCall := channelsGRPCClient.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: channelID}}, nil)
err := mgsdk.SendMessage(context.Background(), tc.domainID, tc.topic, tc.msg, tc.secret)
assert.Equal(t, tc.err, err)
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "Publish", mock.Anything, internalTopic, mock.Anything)
assert.True(t, ok)
-4
View File
@@ -50,10 +50,6 @@ SMQ_CLIENTS_LOG_LEVEL=info SMQ_CLIENTS_HTTP_PORT=9000 SMQ_CLIENTS_GRPC_PORT=7000
###
SMQ_HTTP_ADAPTER_LOG_LEVEL=info SMQ_HTTP_ADAPTER_PORT=8008 SMQ_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-http &
###
# WS
###
SMQ_WS_ADAPTER_LOG_LEVEL=info SMQ_WS_ADAPTER_HTTP_PORT=8190 SMQ_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-ws &
###
# MQTT
+3
View File
@@ -96,6 +96,9 @@ packages:
github.com/absmach/supermq/groups/private:
interfaces:
Service:
github.com/absmach/supermq/http:
interfaces:
Service:
github.com/absmach/supermq/journal:
interfaces:
Repository:
-77
View File
@@ -1,77 +0,0 @@
# WebSocket adapter
WebSocket adapter provides a [WebSocket](https://en.wikipedia.org/wiki/WebSocket#:~:text=WebSocket%20is%20a%20computer%20communications,protocol%20is%20known%20as%20WebSockets.) API for sending and receiving messages through the platform.
## Configuration
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
| Variable | Description | Default |
| --------------------------------- | ----------------------------------------------------------------------------------- | ----------------------------------- |
| SMQ_WS_ADAPTER_LOG_LEVEL | Log level for the WS Adapter (debug, info, warn, error) | info |
| SMQ_WS_ADAPTER_HTTP_HOST | Service WS host | "" |
| SMQ_WS_ADAPTER_HTTP_PORT | Service WS port | 8190 |
| SMQ_WS_ADAPTER_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" |
| SMQ_WS_ADAPTER_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" |
| SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS | Number of cache counters to keep that hold access frequency information | 200000 |
| SMQ_WS_ADAPTER_CACHE_MAX_COST | Maximum size of the cache(in bytes) | 1048576 |
| SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS | Number of cache `Get` buffers | 64 |
| SMQ_CLIENTS_GRPC_URL | Clients service Auth gRPC URL | <localhost:7000> |
| SMQ_CLIENTS_GRPC_TIMEOUT | Clients service Auth gRPC request timeout in seconds | 1s |
| SMQ_CLIENTS_GRPC_CLIENT_CERT | Path to the PEM encoded clients service Auth gRPC client certificate file | "" |
| SMQ_CLIENTS_GRPC_CLIENT_KEY | Path to the PEM encoded clients service Auth gRPC client key file | "" |
| SMQ_CLIENTS_GRPC_SERVER_CERTS | Path to the PEM encoded clients server Auth gRPC server trusted CA certificate file | "" |
| SMQ_MESSAGE_BROKER_URL | Message broker instance URL | <amqp://guest:guest@rabbitmq:5672/> |
| SMQ_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server | true |
| SMQ_WS_ADAPTER_INSTANCE_ID | Service instance ID | "" |
## Deployment
The service is distributed as Docker container. Check the [`ws-adapter`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in docker-compose file to see how the service is deployed.
Running this service outside of container requires working instance of the message broker service, clients service and Jaeger server.
To start the service outside of the container, execute the following shell script:
```bash
# download the latest version of the service
git clone https://github.com/absmach/supermq
cd supermq
# compile the ws
make ws
# copy binary to bin
make install
# set the environment variables and run the service
SMQ_WS_ADAPTER_LOG_LEVEL=info \
SMQ_WS_ADAPTER_HTTP_HOST=localhost \
SMQ_WS_ADAPTER_HTTP_PORT=8190 \
SMQ_WS_ADAPTER_HTTP_SERVER_CERT="" \
SMQ_WS_ADAPTER_HTTP_SERVER_KEY="" \
SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS=200000 \
SMQ_WS_ADAPTER_CACHE_MAX_COST=1048576 \
SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS=64 \
SMQ_CLIENTS_GRPC_URL=localhost:7000 \
SMQ_CLIENTS_GRPC_TIMEOUT=1s \
SMQ_CLIENTS_GRPC_CLIENT_CERT="" \
SMQ_CLIENTS_GRPC_CLIENT_KEY="" \
SMQ_CLIENTS_GRPC_SERVER_CERTS="" \
SMQ_MESSAGE_BROKER_URL=amqp://guest:guest@rabbitmq:5672/ \
SMQ_JAEGER_URL=http://localhost:14268/api/traces \
SMQ_JAEGER_TRACE_RATIO=1.0 \
SMQ_SEND_TELEMETRY=true \
SMQ_WS_ADAPTER_INSTANCE_ID="" \
$GOBIN/supermq-ws
```
Setting `SMQ_WS_ADAPTER_HTTP_SERVER_CERT` and `SMQ_WS_ADAPTER_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key.
Setting `SMQ_CLIENTS_GRPC_CLIENT_CERT` and `SMQ_CLIENTS_GRPC_CLIENT_KEY` will enable TLS against the clients service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_CLIENTS_GRPC_SERVER_CERTS` will enable TLS against the clients service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs.
## Usage
For more information about service capabilities and its usage, please check out the [WebSocket section](https://docs.supermq.absmach.eu/messaging/#websocket).
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package api contains API-related concerns: endpoint definitions, middlewares
// and all resource representations.
package api
-298
View File
@@ -1,298 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/absmach/mgate"
mHttp "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/ws"
"github.com/absmach/supermq/ws/api"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
protocol = "ws"
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
validToken = "Bearer token"
)
var (
msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`)
domainID = testsutil.GenerateUUID(&testing.T{})
id = testsutil.GenerateUUID(&testing.T{})
)
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
pubsub := new(mocks.PubSub)
authn := new(authnMocks.Authentication)
return ws.New(clients, channels, authn, pubsub), pubsub
}
func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server {
mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID)
return httptest.NewServer(mux)
}
func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) {
turl := strings.ReplaceAll(targetServer.URL, "http", "ws")
ptUrl, _ := url.Parse(turl)
ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host)
config := mgate.Config{
Host: "",
Port: "",
PathPrefix: "",
TargetHost: ptHost,
TargetPort: ptPort,
TargetProtocol: ptUrl.Scheme,
TargetPath: ptUrl.Path,
}
mp, err := mHttp.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{})
if err != nil {
return nil, err
}
return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil
}
func makeURL(tsURL, domainID, chanID, subtopic, authKey string, header bool, healthCheck bool) (string, error) {
u, _ := url.Parse(tsURL)
u.Scheme = protocol
if healthCheck {
if header {
return fmt.Sprintf("%s/hc/%s", u, domainID), nil
}
return fmt.Sprintf("%s/hc/%s?authorization=%s", u, domainID, authKey), nil
}
if chanID == "0" || chanID == "" {
if header {
return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id")
}
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, authKey), fmt.Errorf("invalid channel id")
}
subtopicPart := ""
if subtopic != "" {
subtopicPart = fmt.Sprintf("/%s", subtopic)
}
if header {
return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil
}
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, authKey), nil
}
func handshake(tsURL, domainID, chanID, subtopic, authKey string, addHeader bool, healthCheck bool) (*websocket.Conn, *http.Response, error) {
header := http.Header{}
if addHeader {
header.Add("Authorization", authKey)
}
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, authKey, addHeader, healthCheck)
conn, res, errRet := websocket.DefaultDialer.Dial(turl, header)
return conn, res, errRet
}
func TestHandshake(t *testing.T) {
clients := new(climocks.ClientsServiceClient)
channels := new(chmocks.ChannelsServiceClient)
authn := new(authnMocks.Authentication)
domains := new(dmocks.DomainsServiceClient)
resolver := messaging.NewTopicResolver(channels, domains)
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
require.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err))
svc, pubsub := newService(clients, channels)
target := newHTTPServer(svc, resolver)
defer target.Close()
handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, parser)
ts, err := newProxyHTPPServer(handler, target)
require.Nil(t, err)
defer ts.Close()
pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil)
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool {
_, _, key, _ := smqauthn.AuthUnpack(req.Token)
return key == clientKey
})).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
cases := []struct {
desc string
domainID string
chanID string
subtopic string
header bool
authKey string
status int
healthCheck bool
err error
msg []byte
}{
{
desc: "connect and send message",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message with clientKey as query parameter",
domainID: domainID,
chanID: id,
subtopic: "",
header: false,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message with valid token",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: validToken,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message that cannot be published",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: []byte{},
},
{
desc: "connect and send message to subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to nested subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic/nested",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to all subtopics",
domainID: domainID,
chanID: id,
subtopic: ">",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect to empty channel",
domainID: domainID,
chanID: "",
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusUnauthorized,
msg: []byte{},
},
{
desc: "connect with empty clientKey",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: "",
status: http.StatusUnauthorized,
msg: []byte{},
},
{
desc: "connect and send message to subtopic with invalid name",
domainID: domainID,
chanID: id,
subtopic: "sub/a*b/topic",
header: true,
authKey: clientKey,
status: http.StatusUnauthorized,
msg: msg,
},
{
desc: "connect with health check and send message",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
healthCheck: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect with health check and send message with clientKey as query parameter",
domainID: domainID,
chanID: id,
subtopic: "",
header: false,
healthCheck: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.authKey, tc.header, tc.healthCheck)
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
if tc.status == http.StatusSwitchingProtocols {
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err))
err = conn.WriteMessage(websocket.TextMessage, tc.msg)
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err))
}
})
}
}
-130
View File
@@ -1,130 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
"github.com/go-chi/chi/v5"
)
const (
authzHeaderKey = "Authorization"
authzQueryKey = "authorization"
)
var errGenSessionID = errors.New("failed to generate session id")
func generateSessionID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", errors.Wrap(errGenSessionID, err)
}
return hex.EncodeToString(b), nil
}
func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, topicType messaging.TopicType, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r, resolver, logger)
if err != nil {
encodeError(w, err)
return
}
sessionID, err := generateSessionID()
if err != nil {
logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error()))
http.Error(w, "", http.StatusInternalServerError)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
return
}
client := ws.NewClient(logger, conn, sessionID)
client.SetCloseHandler(func(code int, text string) error {
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic, topicType)
})
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.username, req.password, req.domainID, req.channelID, req.subtopic, topicType, client); err != nil {
conn.Close()
return
}
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID))
}
}
func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
username, password, ok := r.BasicAuth()
if !ok {
switch {
case r.URL.Query().Get(authzQueryKey) != "":
password = r.URL.Query().Get(authzQueryKey)
case r.Header.Get(authzHeaderKey) != "":
password = r.Header.Get(authzHeaderKey)
default:
logger.Debug("Missing authorization key.")
return connReq{}, errUnauthorizedAccess
}
}
domain := chi.URLParam(r, "domain")
channel := chi.URLParam(r, "channel")
domainID, channelID, _, err := resolver.Resolve(r.Context(), domain, channel)
if err != nil {
return connReq{}, err
}
req := connReq{
username: username,
password: password,
channelID: channelID,
domainID: domainID,
}
subTopic := chi.URLParam(r, "*")
if subTopic != "" {
subTopic, err := messaging.ParseSubscribeSubtopic(subTopic)
if err != nil {
return connReq{}, err
}
req.subtopic = subTopic
}
return req, nil
}
func encodeError(w http.ResponseWriter, err error) {
var statusCode int
switch err {
case ws.ErrEmptyTopic:
statusCode = http.StatusBadRequest
case errUnauthorizedAccess:
statusCode = http.StatusForbidden
case errMalformedSubtopic, errors.ErrMalformedEntity:
statusCode = http.StatusBadRequest
default:
statusCode = http.StatusNotFound
}
logger.Warn(fmt.Sprintf("Failed to authorize: %s", err.Error()))
w.WriteHeader(statusCode)
}
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
type connReq struct {
username string
password string
channelID string
domainID string
subtopic string
}
-51
View File
@@ -1,51 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"errors"
"log/slog"
"net/http"
"github.com/absmach/supermq"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
const (
service = "ws"
readwriteBufferSize = 1024
)
var (
errUnauthorizedAccess = errors.New("missing or invalid credentials provided")
errMalformedSubtopic = errors.New("malformed subtopic")
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: readwriteBufferSize,
WriteBufferSize: readwriteBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
}
logger *slog.Logger
)
// MakeHandler returns http handler with handshake endpoint.
func MakeHandler(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, l *slog.Logger, instanceID string) http.Handler {
logger = l
mux := chi.NewRouter()
mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, resolver, messaging.MessageType, l))
mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, resolver, messaging.MessageType, l))
mux.Get("/hc/{domain}", handshake(ctx, svc, resolver, messaging.HealthType, l))
mux.Get("/health", supermq.Health(service, instanceID))
mux.Handle("/metrics", promhttp.Handler())
return mux
}
-15
View File
@@ -1,15 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package ws provides domain concept definitions required to support
// SuperMQ WebSocket adapter service functionality.
//
// This package defines the core domain concepts and types necessary to handle
// WebSocket connections and messages in the context of a SuperMQ WebSocket
// adapter service. It abstracts the underlying complexities of WebSocket
// communication and provides a structured approach to working with WebSocket
// clients and servers.
//
// For more details about SuperMQ messaging and WebSocket adapter service,
// please refer to the documentation at https://docs.supermq.absmach.eu/messaging/#websocket.
package ws
-281
View File
@@ -1,281 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
mgate "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/policies"
)
var _ session.Handler = (*handler)(nil)
const protocol = "websocket"
// Log message formats.
const (
LogInfoSubscribed = "subscribed with client_id %s to topics %s"
LogInfoConnected = "connected with client_id %s"
LogInfoDisconnected = "disconnected client_id %s and username %s"
LogInfoPublished = "published with client_id %s to the topic %s"
)
// Error wrappers for MQTT errors.
var (
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
errFailedPublish = errors.New("failed to publish")
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
errInvalidAuthFormat = errors.New("invalid basic auth format")
errInvalidClientType = errors.New("invalid client type")
)
// Event implements events.Event interface.
type handler struct {
pubsub messaging.PubSub
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
authn smqauthn.Authentication
logger *slog.Logger
parser messaging.TopicParser
}
// NewHandler creates new Handler entity.
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler {
return &handler{
logger: logger,
pubsub: pubsub,
authn: authn,
clients: clients,
channels: channels,
parser: parser,
}
}
// AuthConnect is called on device connection,
// prior forwarding to the ws server.
func (h *handler) AuthConnect(ctx context.Context) error {
return nil
}
// AuthPublish is called on device publish,
// prior forwarding to the ws server.
func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
if topic == nil {
return errMissingTopicPub
}
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
domainID, channelID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType)
if err != nil {
return err
}
if s.Username == "" {
s.Username = clientID
}
return nil
}
// AuthSubscribe is called on device publish,
// prior forwarding to the MQTT broker.
func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
if topics == nil || *topics == nil {
return errMissingTopicSub
}
for _, topic := range *topics {
domainID, channelID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, true)
if err != nil {
return err
}
if _, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
return err
}
}
return nil
}
// Connect - after client successfully connected.
func (h *handler) Connect(ctx context.Context) error {
return nil
}
// Publish - after client successfully published.
func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
if len(*payload) == 0 {
h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username))
return nil
}
domainID, channelID, subtopic, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return errors.Wrap(errFailedPublish, err)
}
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: channelID,
Subtopic: subtopic,
Payload: *payload,
Publisher: s.Username,
Created: time.Now().UnixNano(),
}
// Health check topic messages do not get published to message broker.
if topicType == messaging.MessageType {
if err := h.pubsub.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err))
}
}
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
return nil
}
// Subscribe - after client successfully subscribed.
func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
s, ok := session.FromContext(ctx)
if !ok {
return errClientNotInitialized
}
h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ",")))
return nil
}
// Unsubscribe - after client unsubscribed.
func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
return nil
}
// Disconnect - connection with broker or client lost.
func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var token, clientType string
var err error
switch {
case strings.HasPrefix(password, apiutil.BearerPrefix):
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
clientType = policies.UserType
case username != "" && password != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
default:
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
clientType = policies.ClientType
}
id, err := h.authenticate(ctx, clientType, token)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
// Health check topics do not require channel authorization.
if topicType == messaging.HealthType {
return id, nil
}
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: id,
ClientType: clientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !res.GetAuthorized() {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
return id, nil
}
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
switch authType {
case policies.UserType:
authnSession, err := h.authn.Authenticate(ctx, token)
if err != nil {
return "", err
}
return authnSession.UserID, nil
case policies.ClientType:
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
if err != nil {
return "", err
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
return authnRes.GetId(), nil
default:
return "", errInvalidClientType
}
}
// decodeAuth decodes the base64 encoded string in the format "clientID:secret".
func decodeAuth(s string) (string, string, error) {
db, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", "", err
}
parts := strings.SplitN(string(db), ":", 2)
if len(parts) != 2 {
return "", "", errInvalidAuthFormat
}
clientID := parts[0]
secret := parts[1]
return clientID, secret, nil
}
-738
View File
@@ -1,738 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws_test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
"testing"
mgate "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
chmocks "github.com/absmach/supermq/channels/mocks"
clmocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/ws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var (
invalidValue = "invalid"
topicMsg = "/m/%s/c/%s"
subtopicMsg = "/m/%s/c/%s/subtopic"
topic = fmt.Sprintf(topicMsg, domainID, chanID)
subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID)
hcTopicFmt = "/hc/%s"
hcTopic = fmt.Sprintf(hcTopicFmt, domainID)
invalidHCTopic = "/hc"
invalidTopic = invalidValue
topics = []string{topic}
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
sessionClient = session.Session{
ID: clientID,
Password: []byte(clientKey),
}
invalidChannelIDTopic = "m/**/c"
validToken = "token"
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
)
var (
clients = new(clmocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
authn = new(authnmocks.Authentication)
publisher = new(mocks.PubSub)
domains = new(dmocks.DomainsServiceClient)
)
func newHandler(t *testing.T) session.Handler {
logger := smqlog.NewMock()
authn = new(authnmocks.Authentication)
clients = new(clmocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
publisher = new(mocks.PubSub)
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err))
return ws.NewHandler(publisher, logger, authn, clients, channels, parser)
}
func TestAuthPublish(t *testing.T) {
handler := newHandler(t)
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
session *session.Session
topic *string
payload *[]byte
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
err error
}{
{
desc: "publish with client key successfully",
session: &clientKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid client key",
session: &invalidClientKeySession,
topic: &topic,
authKey: invalidKey,
payload: &payload,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with nil session",
session: nil,
topic: &topic,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "publish with empty topic",
session: &clientKeySession,
topic: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicPub,
},
{
desc: "publish with unauthorized client key",
session: &unauthorizedKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with token successfully",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid token",
session: &invalidTokenSession,
topic: &topic,
authKey: invalidToken,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with unauthorized token",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with basic auth successfully",
session: &basicAuthSession,
topic: &topic,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid basic auth",
session: &invalidBasicAuthSession,
topic: &topic,
payload: &payload,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topic: &topic,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topic: &topic,
payload: &payload,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with health check topic successfully",
session: &hcClientKeySession,
topic: &hcTopic,
authKey: clientKey,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: "",
domainID: domainID,
clientID: userID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: true},
authNErr: nil,
},
{
desc: "publish with invalid health check topic",
session: &hcClientKeySession,
topic: &invalidHCTopic,
authKey: clientKey,
payload: &payload,
status: http.StatusBadRequest,
err: messaging.ErrMalformedTopic,
clientType: policies.ClientType,
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Publish),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthPublish(ctx, tc.topic, tc.payload)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestAuthSubscribe(t *testing.T) {
handler := newHandler(t)
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
session *session.Session
topics *[]string
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
err error
}{
{
desc: "subscribe with client key successfully",
session: &clientKeySession,
topics: &topics,
authKey: clientKey,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid client key",
session: &invalidClientKeySession,
topics: &topics,
authKey: invalidKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with empty topics",
session: &clientKeySession,
topics: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicSub,
},
{
desc: "subscribe with nil session",
session: nil,
topics: &topics,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "subscribe with unauthorized client key",
session: &unauthorizedKeySession,
topics: &topics,
authKey: clientKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with token successfully",
session: &tokenSession,
topics: &topics,
authKey: token,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid token",
session: &invalidTokenSession,
topics: &topics,
authKey: invalidToken,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with unauthorized token",
session: &tokenSession,
topics: &topics,
authKey: token,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with basic auth successfully",
session: &basicAuthSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid basic auth",
session: &invalidBasicAuthSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with health check topic successfully",
session: &hcClientKeySession,
topics: &[]string{hcTopic},
authKey: clientKey,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
err: nil,
},
{
desc: "subscribe with invalid health check topic",
session: &hcClientKeySession,
topics: &[]string{invalidHCTopic},
authKey: clientKey,
status: http.StatusBadRequest,
err: messaging.ErrMalformedTopic,
clientType: policies.ClientType,
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthSubscribe(ctx, tc.topics)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestPublish(t *testing.T) {
handler := newHandler(t)
malformedSubtopics := topic + "/" + subtopic + "%"
wrongCharSubtopics := topic + "/" + subtopic + ">"
validSubtopic := topic + "/" + subtopic
cases := []struct {
desc string
session *session.Session
topic string
payload []byte
err error
}{
{
desc: "publish without active session",
session: nil,
topic: topic,
payload: payload,
err: errClientNotInitialized,
},
{
desc: "publish with invalid topic",
session: &sessionClient,
topic: invalidTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with invalid channel ID",
session: &sessionClient,
topic: invalidChannelIDTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with malformed subtopic",
session: &sessionClient,
topic: malformedSubtopics,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with subtopic containing wrong character",
session: &sessionClient,
topic: wrongCharSubtopics,
payload: payload,
err: messaging.ErrMalformedSubtopic,
},
{
desc: "publish with subtopic",
session: &sessionClient,
topic: validSubtopic,
payload: payload,
},
{
desc: "publish without subtopic",
session: &sessionClient,
topic: topic,
payload: payload,
},
{
desc: "publish with health check topic",
session: &sessionClient,
topic: hcTopic,
payload: payload,
},
{
desc: "puvlish with invalid health check topic",
session: &sessionClient,
topic: invalidHCTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
}
for _, tc := range cases {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
repoCall.Unset()
}
}