SMQ - 2546 - Add telemetry aggregation for clients telemetry (#2661)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-02-03 13:37:25 +03:00
committed by GitHub
parent 07dbb86203
commit 484de372ec
19 changed files with 758 additions and 217 deletions
@@ -61,7 +61,6 @@ jobs:
- "invitations/invitations.go"
- "users/emailer.go"
- "users/hasher.go"
- "mqtt/events/streams.go"
- "certs/certs.go"
- "certs/pki/vault.go"
- "certs/service.go"
@@ -149,7 +148,6 @@ jobs:
mv ./clients/mocks/clients_client.go ./clients/mocks/clients_client.go.tmp
mv ./clients/mocks/cache.go ./clients/mocks/cache.go.tmp
mv ./clients/mocks/service.go ./clients/mocks/service.go.tmp
mv ./mqtt/mocks/events.go ./mqtt/mocks/events.go.tmp
mv ./readers/mocks/messages.go ./readers/mocks/messages.go.tmp
mv ./pkg/sdk/mocks/sdk.go ./pkg/sdk/mocks/sdk.go.tmp
mv ./pkg/messaging/mocks/pubsub.go ./pkg/messaging/mocks/pubsub.go.tmp
@@ -208,7 +206,6 @@ jobs:
check_mock_changes ./clients/mocks/clients_client.go " ./clients/mocks/clients_client.go"
check_mock_changes ./clients/mocks/cache.go " ./clients/mocks/cache.go"
check_mock_changes ./clients/mocks/service.go " ./clients/mocks/service.go"
check_mock_changes ./mqtt/mocks/events.go " ./mqtt/mocks/events.go"
check_mock_changes ./readers/mocks/messages.go " ./readers/mocks/messages.go"
check_mock_changes ./pkg/sdk/mocks/sdk.go " ./pkg/sdk/mocks/sdk.go"
check_mock_changes ./pkg/messaging/mocks/pubsub.go " ./pkg/messaging/mocks/pubsub.go"
+9 -15
View File
@@ -135,13 +135,6 @@ func main() {
defer bsub.Close()
bsub = brokerstracing.NewPubSub(serverConfig, tracer, bsub)
bsub, err = msgevents.NewPubSubMiddleware(ctx, bsub, cfg.ESURL)
if err != nil {
logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err))
exitCode = 1
return
}
mpub, err := mqttpub.NewPublisher(fmt.Sprintf("mqtt://%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort), cfg.MQTTQoS, cfg.MQTTForwarderTimeout)
if err != nil {
logger.Error(fmt.Sprintf("failed to create MQTT publisher: %s", err))
@@ -181,13 +174,6 @@ func main() {
return
}
es, err := events.NewEventStore(ctx, cfg.ESURL, cfg.Instance)
if err != nil {
logger.Error(fmt.Sprintf("failed to create %s event store : %s", svcName, err))
exitCode = 1
return
}
clientsClientCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err))
@@ -220,7 +206,15 @@ func main() {
defer channelsHandler.Close()
logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure())
h := mqtt.NewHandler(np, es, logger, clientsClient, channelsClient)
h := mqtt.NewHandler(np, logger, clientsClient, channelsClient)
h, err = events.NewEventStoreMiddleware(ctx, h, cfg.ESURL, cfg.Instance)
if err != nil {
logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err))
exitCode = 1
return
}
h = handler.NewTracing(tracer, h)
if cfg.SendTelemetry {
+4 -3
View File
@@ -121,9 +121,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic
authzc := newAuthzClient(clientID, chanID, subtopic, svc.channels, c)
subCfg := messaging.SubscriberConfig{
ID: c.Token(),
Topic: subject,
Handler: authzc,
ID: c.Token(),
ClientID: clientID,
Topic: subject,
Handler: authzc,
}
return svc.pubsub.Subscribe(ctx, subCfg)
}
+16 -1
View File
@@ -10,7 +10,10 @@ import (
"github.com/absmach/supermq/journal"
)
var _ supermq.Response = (*pageRes)(nil)
var (
_ supermq.Response = (*pageRes)(nil)
_ supermq.Response = (*clientTelemetryRes)(nil)
)
type pageRes struct {
journal.JournalsPage `json:",inline"`
@@ -31,3 +34,15 @@ func (res pageRes) Empty() bool {
type clientTelemetryRes struct {
journal.ClientTelemetry `json:",inline"`
}
func (res clientTelemetryRes) Headers() map[string]string {
return map[string]string{}
}
func (res clientTelemetryRes) Code() int {
return http.StatusOK
}
func (res clientTelemetryRes) Empty() bool {
return false
}
+24 -1
View File
@@ -140,13 +140,21 @@ func (page JournalsPage) MarshalJSON() ([]byte, error) {
type ClientTelemetry struct {
ClientID string `json:"client_id"`
DomainID string `json:"domain_id"`
Subscriptions []string `json:"subscriptions"`
Subscriptions uint64 `json:"subscriptions"`
InboundMessages uint64 `json:"inbound_messages"`
OutboundMessages uint64 `json:"outbound_messages"`
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
}
type ClientSubscription struct {
ID string `json:"id" db:"id"`
SubscriberID string `json:"subscriber_id" db:"subscriber_id"`
ChannelID string `json:"channel_id" db:"channel_id"`
Subtopic string `json:"subtopic" db:"subtopic"`
ClientID string `json:"client_id" db:"client_id"`
}
// Service provides access to the journal log service.
//
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
@@ -179,4 +187,19 @@ type Repository interface {
// DeleteClientTelemetry removes telemetry data for a client from the database.
DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error
// AddSubscription adds a subscription to the client telemetry.
AddSubscription(ctx context.Context, sub ClientSubscription) error
// CountSubscriptions returns the number of subscriptions for a client.
CountSubscriptions(ctx context.Context, clientID string) (uint64, error)
// RemoveSubscription removes a subscription from the client telemetry.
RemoveSubscription(ctx context.Context, subscriberID string) error
// IncrementInboundMessages increments the inbound messages count for a client.
IncrementInboundMessages(ctx context.Context, clientID string) error
// IncrementOutboundMessages increments the outbound messages count for a client.
IncrementOutboundMessages(ctx context.Context, channelID, subtopic string) error
}
+1 -1
View File
@@ -70,7 +70,7 @@ func (am *authorizationMiddleware) RetrieveClientTelemetry(ctx context.Context,
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.UserID,
Subject: session.DomainUserID,
Permission: readPermission,
ObjectType: policies.ClientType,
Object: clientID,
+100
View File
@@ -16,6 +16,52 @@ type Repository struct {
mock.Mock
}
// AddSubscription provides a mock function with given fields: ctx, sub
func (_m *Repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error {
ret := _m.Called(ctx, sub)
if len(ret) == 0 {
panic("no return value specified for AddSubscription")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, journal.ClientSubscription) error); ok {
r0 = rf(ctx, sub)
} else {
r0 = ret.Error(0)
}
return r0
}
// CountSubscriptions provides a mock function with given fields: ctx, clientID
func (_m *Repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) {
ret := _m.Called(ctx, clientID)
if len(ret) == 0 {
panic("no return value specified for CountSubscriptions")
}
var r0 uint64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (uint64, error)); ok {
return rf(ctx, clientID)
}
if rf, ok := ret.Get(0).(func(context.Context, string) uint64); ok {
r0 = rf(ctx, clientID)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, clientID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeleteClientTelemetry provides a mock function with given fields: ctx, clientID, domainID
func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string, domainID string) error {
ret := _m.Called(ctx, clientID, domainID)
@@ -34,6 +80,60 @@ func (_m *Repository) DeleteClientTelemetry(ctx context.Context, clientID string
return r0
}
// IncrementInboundMessages provides a mock function with given fields: ctx, clientID
func (_m *Repository) IncrementInboundMessages(ctx context.Context, clientID string) error {
ret := _m.Called(ctx, clientID)
if len(ret) == 0 {
panic("no return value specified for IncrementInboundMessages")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, clientID)
} else {
r0 = ret.Error(0)
}
return r0
}
// IncrementOutboundMessages provides a mock function with given fields: ctx, channelID, subtopic
func (_m *Repository) IncrementOutboundMessages(ctx context.Context, channelID string, subtopic string) error {
ret := _m.Called(ctx, channelID, subtopic)
if len(ret) == 0 {
panic("no return value specified for IncrementOutboundMessages")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, channelID, subtopic)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveSubscription provides a mock function with given fields: ctx, subscriberID
func (_m *Repository) RemoveSubscription(ctx context.Context, subscriberID string) error {
ret := _m.Called(ctx, subscriberID)
if len(ret) == 0 {
panic("no return value specified for RemoveSubscription")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, subscriberID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RetrieveAll provides a mock function with given fields: ctx, page
func (_m *Repository) RetrieveAll(ctx context.Context, page journal.Page) (journal.JournalsPage, error) {
ret := _m.Called(ctx, page)
+11 -4
View File
@@ -28,18 +28,25 @@ func Migration() *migrate.MemoryMigrationSource {
`CREATE INDEX idx_journal_default_client_filter ON journal(operation, (attributes->>'id'), (attributes->>'client_id'), occurred_at DESC);`,
`CREATE INDEX idx_journal_default_channel_filter ON journal(operation, (attributes->>'id'), (attributes->>'channel_id'), occurred_at DESC);`,
`CREATE TABLE IF NOT EXISTS clients_telemetry (
client_id VARCHAR(36) NOT NULL,
client_id VARCHAR(36) PRIMARY KEY,
domain_id VARCHAR(36) NOT NULL,
subscriptions TEXT[],
inbound_messages BIGINT DEFAULT 0,
outbound_messages BIGINT DEFAULT 0,
first_seen TIMESTAMP,
last_seen TIMESTAMP,
PRIMARY KEY (client_id, domain_id)
last_seen TIMESTAMP
)`,
`CREATE TABLE IF NOT EXISTS subscriptions (
id VARCHAR(36) PRIMARY KEY,
subscriber_id VARCHAR(1024) NOT NULL,
channel_id VARCHAR(36) NOT NULL,
subtopic VARCHAR(1024),
client_id VARCHAR(36),
FOREIGN KEY (client_id) REFERENCES clients_telemetry(client_id) ON DELETE CASCADE ON UPDATE CASCADE
)`,
},
Down: []string{
`DROP TABLE IF EXISTS clients_telemetry`,
`DROP TABLE IF EXISTS subscriptions`,
`DROP TABLE IF EXISTS journal`,
},
},
+139 -18
View File
@@ -16,8 +16,8 @@ import (
)
func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.ClientTelemetry) error {
q := `INSERT INTO clients_telemetry (client_id, domain_id, messages, subscriptions, first_seen, last_seen)
VALUES (:client_id, :domain_id, :messages, :subscriptions, :first_seen, :last_seen);`
q := `INSERT INTO clients_telemetry (client_id, domain_id, inbound_messages, outbound_messages, first_seen, last_seen)
VALUES (:client_id, :domain_id, :inbound_messages, :outbound_messages, :first_seen, :last_seen);`
dbct, err := toDBClientsTelemetry(ct)
if err != nil {
@@ -32,7 +32,7 @@ func (repo *repository) SaveClientTelemetry(ctx context.Context, ct journal.Clie
}
func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, domainID string) error {
q := "DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;"
q := `DELETE FROM clients_telemetry AS ct WHERE ct.client_id = :client_id AND ct.domain_id = :domain_id;`
dbct := dbClientTelemetry{
ClientID: clientID,
@@ -50,7 +50,7 @@ func (repo *repository) DeleteClientTelemetry(ctx context.Context, clientID, dom
}
func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, domainID string) (journal.ClientTelemetry, error) {
q := "SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;"
q := `SELECT * FROM clients_telemetry WHERE client_id = :client_id AND domain_id = :domain_id;`
dbct := dbClientTelemetry{
ClientID: clientID,
@@ -80,14 +80,142 @@ func (repo *repository) RetrieveClientTelemetry(ctx context.Context, clientID, d
return journal.ClientTelemetry{}, repoerr.ErrNotFound
}
func (repo *repository) AddSubscription(ctx context.Context, sub journal.ClientSubscription) error {
q := `INSERT INTO subscriptions (id, subscriber_id, channel_id, subtopic, client_id)
VALUES (:id, :subscriber_id, :channel_id, :subtopic, :client_id);
`
result, err := repo.db.NamedExecContext(ctx, q, sub)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo *repository) CountSubscriptions(ctx context.Context, clientID string) (uint64, error) {
q := `SELECT COUNT(*) FROM subscriptions WHERE client_id = :client_id;`
sb := journal.ClientSubscription{
ClientID: clientID,
}
total, err := postgres.Total(ctx, repo.db, q, sb)
if err != nil {
return 0, postgres.HandleError(repoerr.ErrViewEntity, err)
}
return total, nil
}
func (repo *repository) RemoveSubscription(ctx context.Context, subscriberID string) error {
q := `DELETE FROM subscriptions WHERE subscriber_id = :subscriber_id;`
sb := journal.ClientSubscription{
SubscriberID: subscriberID,
}
_, err := repo.db.NamedExecContext(ctx, q, sb)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
return nil
}
func (repo *repository) IncrementInboundMessages(ctx context.Context, clientID string) error {
q := `
UPDATE clients_telemetry
SET inbound_messages = inbound_messages + 1,
last_seen = :last_seen
WHERE client_id = :client_id;
`
ct := journal.ClientTelemetry{
ClientID: clientID,
LastSeen: time.Now(),
}
dbct, err := toDBClientsTelemetry(ct)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
result, err := repo.db.NamedExecContext(ctx, q, dbct)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo *repository) IncrementOutboundMessages(ctx context.Context, channelID, subtopic string) error {
query := `
SELECT client_id, COUNT(*) AS match_count
FROM subscriptions
WHERE channel_id = :channel_id AND subtopic = :subtopic
GROUP BY client_id
`
sb := journal.ClientSubscription{
ChannelID: channelID,
Subtopic: subtopic,
}
rows, err := repo.db.NamedQueryContext(ctx, query, sb)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
tx, err := repo.db.BeginTxx(ctx, nil)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
q := `UPDATE clients_telemetry
SET outbound_messages = outbound_messages + $1
WHERE client_id = $2;
`
for rows.Next() {
var clientID string
var count uint64
if err = rows.Scan(&clientID, &count); err != nil {
if err := tx.Rollback(); err != nil {
return errors.Wrap(errors.ErrRollbackTx, err)
}
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if _, err = repo.db.ExecContext(ctx, q, count, clientID); err != nil {
if err := tx.Rollback(); err != nil {
return errors.Wrap(errors.ErrRollbackTx, err)
}
return errors.Wrap(errors.ErrRollbackTx, err)
}
}
if err = tx.Commit(); err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
return nil
}
type dbClientTelemetry struct {
ClientID string `db:"client_id"`
DomainID string `db:"domain_id"`
Subscriptions pgtype.TextArray `db:"subscriptions"`
InboundMessages uint64 `db:"inbound_messages"`
OutboundMessages uint64 `db:"outbound_messages"`
FirstSeen time.Time `db:"first_seen"`
LastSeen sql.NullTime `db:"last_seen"`
ClientID string `db:"client_id"`
DomainID string `db:"domain_id"`
InboundMessages uint64 `db:"inbound_messages"`
OutboundMessages uint64 `db:"outbound_messages"`
FirstSeen time.Time `db:"first_seen"`
LastSeen sql.NullTime `db:"last_seen"`
}
func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error) {
@@ -104,7 +232,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error)
return dbClientTelemetry{
ClientID: ct.ClientID,
DomainID: ct.DomainID,
Subscriptions: subs,
InboundMessages: ct.InboundMessages,
OutboundMessages: ct.OutboundMessages,
FirstSeen: ct.FirstSeen,
@@ -113,11 +240,6 @@ func toDBClientsTelemetry(ct journal.ClientTelemetry) (dbClientTelemetry, error)
}
func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error) {
var subs []string
for _, e := range dbct.Subscriptions.Elements {
subs = append(subs, e.String)
}
var lastSeen time.Time
if dbct.LastSeen.Valid {
lastSeen = dbct.LastSeen.Time
@@ -126,7 +248,6 @@ func toClientsTelemetry(dbct dbClientTelemetry) (journal.ClientTelemetry, error)
return journal.ClientTelemetry{
ClientID: dbct.ClientID,
DomainID: dbct.DomainID,
Subscriptions: subs,
InboundMessages: dbct.InboundMessages,
OutboundMessages: dbct.OutboundMessages,
FirstSeen: dbct.FirstSeen,
+308 -1
View File
@@ -5,6 +5,9 @@ package journal
import (
"context"
"fmt"
"strings"
"time"
"github.com/absmach/supermq"
smqauthn "github.com/absmach/supermq/pkg/authn"
@@ -12,6 +15,16 @@ import (
svcerr "github.com/absmach/supermq/pkg/errors/service"
)
const (
clientCreate = "client.create"
clientRemove = "client.remove"
mqttSubscribe = "mqtt.client_subscribe"
mqttDisconnect = "mqtt.client_disconnect"
messagingPublish = "messaging.client_publish"
messagingSubscribe = "messaging.client_subscribe"
messagingUnsubscribe = "messaging.client_unsubscribe"
)
type service struct {
idProvider supermq.IDProvider
repository Repository
@@ -31,7 +44,14 @@ func (svc *service) Save(ctx context.Context, journal Journal) error {
}
journal.ID = id
return svc.repository.Save(ctx, journal)
if err := svc.repository.Save(ctx, journal); err != nil {
return err
}
if err := svc.handleTelemetry(ctx, journal); err != nil {
return err
}
return nil
}
func (svc *service) RetrieveAll(ctx context.Context, session smqauthn.Session, page Page) (JournalsPage, error) {
@@ -49,5 +69,292 @@ func (svc *service) RetrieveClientTelemetry(ctx context.Context, session smqauth
return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
subs, err := svc.repository.CountSubscriptions(ctx, clientID)
if err != nil {
return ClientTelemetry{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
ct.Subscriptions = subs
return ct, nil
}
func (svc *service) handleTelemetry(ctx context.Context, journal Journal) error {
switch journal.Operation {
case clientCreate:
return svc.addClientTelemetry(ctx, journal)
case clientRemove:
return svc.removeClientTelemetry(ctx, journal)
case mqttSubscribe:
return svc.addMqttSubscription(ctx, journal)
case messagingSubscribe:
return svc.addSubscription(ctx, journal)
case messagingUnsubscribe:
return svc.removeSubscription(ctx, journal)
case messagingPublish:
return svc.updateMessageCount(ctx, journal)
case mqttDisconnect:
return svc.removeMqttSubscription(ctx, journal)
default:
return nil
}
}
func (svc *service) addClientTelemetry(ctx context.Context, journal Journal) error {
ce, err := toClientEvent(journal)
if err != nil {
return err
}
ct := ClientTelemetry{
ClientID: ce.id,
DomainID: ce.domain,
FirstSeen: ce.createdAt,
LastSeen: ce.createdAt,
}
return svc.repository.SaveClientTelemetry(ctx, ct)
}
func (svc *service) removeClientTelemetry(ctx context.Context, journal Journal) error {
ce, err := toClientEvent(journal)
if err != nil {
return err
}
return svc.repository.DeleteClientTelemetry(ctx, ce.id, ce.domain)
}
func (svc *service) addSubscription(ctx context.Context, journal Journal) error {
ae, err := toSubscribeEvent(journal)
if err != nil {
return err
}
var subtopic string
topics := strings.Split(ae.topic, ".")
if len(topics) > 2 {
subtopic = topics[2]
}
id, err := svc.idProvider.ID()
if err != nil {
return err
}
sub := ClientSubscription{
ID: id,
SubscriberID: ae.subscriberID,
ChannelID: topics[1],
Subtopic: subtopic,
ClientID: ae.clientID,
}
return svc.repository.AddSubscription(ctx, sub)
}
func (svc *service) addMqttSubscription(ctx context.Context, journal Journal) error {
ae, err := toMqttSubscribeEvent(journal)
if err != nil {
return err
}
id, err := svc.idProvider.ID()
if err != nil {
return err
}
sub := ClientSubscription{
ID: id,
SubscriberID: ae.subscriberID,
ChannelID: ae.channelID,
Subtopic: ae.subtopic,
ClientID: ae.clientID,
}
return svc.repository.AddSubscription(ctx, sub)
}
func (svc *service) removeSubscription(ctx context.Context, journal Journal) error {
ae, err := toUnsubscribeEvent(journal)
if err != nil {
return err
}
return svc.repository.RemoveSubscription(ctx, ae.subscriberID)
}
func (svc *service) removeMqttSubscription(ctx context.Context, journal Journal) error {
ae, err := toMqttDisconnectEvent(journal)
if err != nil {
return err
}
return svc.repository.RemoveSubscription(ctx, ae.subscriberID)
}
func (svc *service) updateMessageCount(ctx context.Context, journal Journal) error {
ae, err := toPublishEvent(journal)
if err != nil {
return err
}
if err := svc.repository.IncrementInboundMessages(ctx, ae.clientID); err != nil {
return err
}
if err := svc.repository.IncrementOutboundMessages(ctx, ae.channelID, ae.subtopic); err != nil {
return err
}
return nil
}
type clientEvent struct {
id string
domain string
createdAt time.Time
}
func toClientEvent(journal Journal) (clientEvent, error) {
var createdAt time.Time
var err error
id, err := getStringAttribute(journal, "id")
if err != nil {
return clientEvent{}, err
}
domain, err := getStringAttribute(journal, "domain")
if err != nil {
return clientEvent{}, err
}
createdAtStr := journal.Attributes["created_at"].(string)
if createdAtStr != "" {
createdAt, err = time.Parse(time.RFC3339, createdAtStr)
if err != nil {
return clientEvent{}, fmt.Errorf("invalid created_at format")
}
}
return clientEvent{
id: id,
domain: domain,
createdAt: createdAt,
}, nil
}
type adapterEvent struct {
clientID string
channelID string
subscriberID string
topic string
subtopic string
}
func toPublishEvent(journal Journal) (adapterEvent, error) {
clientID, err := getStringAttribute(journal, "client_id")
if err != nil {
return adapterEvent{}, err
}
channelID, err := getStringAttribute(journal, "channel_id")
if err != nil {
return adapterEvent{}, err
}
subtopic, err := getStringAttribute(journal, "subtopic")
if err != nil {
return adapterEvent{}, err
}
return adapterEvent{
clientID: clientID,
channelID: channelID,
subtopic: subtopic,
}, nil
}
func toSubscribeEvent(journal Journal) (adapterEvent, error) {
subscriberID, err := getStringAttribute(journal, "subscriber_id")
if err != nil {
return adapterEvent{}, err
}
topic, err := getStringAttribute(journal, "topic")
if err != nil {
return adapterEvent{}, err
}
var clientID string
clientID, err = getStringAttribute(journal, "client_id")
if err != nil {
clientID = ""
}
return adapterEvent{
clientID: clientID,
subscriberID: subscriberID,
topic: topic,
}, nil
}
func toUnsubscribeEvent(journal Journal) (adapterEvent, error) {
subscriberID, err := getStringAttribute(journal, "subscriber_id")
if err != nil {
return adapterEvent{}, err
}
topic, err := getStringAttribute(journal, "topic")
if err != nil {
return adapterEvent{}, err
}
return adapterEvent{
subscriberID: subscriberID,
topic: topic,
}, nil
}
func toMqttSubscribeEvent(journal Journal) (adapterEvent, error) {
clientID, err := getStringAttribute(journal, "client_id")
if err != nil {
return adapterEvent{}, err
}
subscriberID, err := getStringAttribute(journal, "subscriber_id")
if err != nil {
return adapterEvent{}, err
}
channelID, err := getStringAttribute(journal, "channel_id")
if err != nil {
return adapterEvent{}, err
}
subtopic, err := getStringAttribute(journal, "subtopic")
if err != nil {
return adapterEvent{}, err
}
return adapterEvent{
clientID: clientID,
subscriberID: subscriberID,
channelID: channelID,
subtopic: subtopic,
}, nil
}
func toMqttDisconnectEvent(journal Journal) (adapterEvent, error) {
subscriberID, err := getStringAttribute(journal, "subscriber_id")
if err != nil {
return adapterEvent{}, err
}
clientID, err := getStringAttribute(journal, "client_id")
if err != nil {
return adapterEvent{}, err
}
return adapterEvent{
subscriberID: subscriberID,
channelID: clientID,
}, nil
}
func getStringAttribute(journal Journal, key string) (string, error) {
value, ok := journal.Attributes[key].(string)
if !ok {
return "", fmt.Errorf("missing or invalid %s attribute", key)
}
return value, nil
}
+123 -26
View File
@@ -5,73 +5,170 @@ package events
import (
"context"
"net/url"
"regexp"
"strings"
"github.com/absmach/mgate/pkg/session"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/events/store"
)
const streamID = "supermq.mqtt"
//go:generate mockery --name EventStore --output=../mocks --filename events.go --quiet --note "Copyright (c) Abstract Machines"
type EventStore interface {
Connect(ctx context.Context, clientID, subscriberID string) error
Disconnect(ctx context.Context, clientID, subscriberID string) error
Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error
}
var (
errFailedSession = errors.New("failed to obtain session from context")
errMalformedTopic = errors.New("malformed topic")
channelRegExp = regexp.MustCompile(`^\/?channels\/([\w\-]+)\/messages(\/[^?]*)?(\?.*)?$`)
)
// EventStore is a struct used to store event streams in Redis.
type eventStore struct {
ep events.Publisher
handler session.Handler
instance string
}
// NewEventStore returns wrapper around mProxy service that sends
// NewEventStoreMiddleware returns middleware around mGate service that sends
// events to event store.
func NewEventStore(ctx context.Context, url, instance string) (EventStore, error) {
func NewEventStoreMiddleware(ctx context.Context, handler session.Handler, url, instance string) (session.Handler, error) {
publisher, err := store.NewPublisher(ctx, url, streamID)
if err != nil {
return nil, err
}
return &eventStore{
instance: instance,
ep: publisher,
handler: handler,
instance: instance,
}, nil
}
// Connect issues event on MQTT CONNECT.
func (es *eventStore) Connect(ctx context.Context, clientID, subscriberID string) error {
func (es *eventStore) AuthConnect(ctx context.Context) error {
if err := es.handler.AuthConnect(ctx); err != nil {
return err
}
s, ok := session.FromContext(ctx)
if !ok {
return errFailedSession
}
ev := connectEvent{
clientID: clientID,
operation: clientConnect,
subscriberID: subscriberID,
clientID: s.Username,
subscriberID: s.ID,
instance: es.instance,
}
return es.ep.Publish(ctx, ev)
}
// Disconnect issues event on MQTT CONNECT.
func (es *eventStore) Disconnect(ctx context.Context, clientID, subscriberID string) error {
func (es *eventStore) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
return es.handler.AuthPublish(ctx, topic, payload)
}
func (es *eventStore) AuthSubscribe(ctx context.Context, topics *[]string) error {
return es.handler.AuthSubscribe(ctx, topics)
}
func (es *eventStore) Connect(ctx context.Context) error {
return es.handler.Connect(ctx)
}
func (es *eventStore) Publish(ctx context.Context, topic *string, payload *[]byte) error {
return es.handler.Publish(ctx, topic, payload)
}
func (es *eventStore) Subscribe(ctx context.Context, topics *[]string) error {
if err := es.handler.Subscribe(ctx, topics); err != nil {
return err
}
s, ok := session.FromContext(ctx)
if !ok {
return errFailedSession
}
for _, topic := range *topics {
channelID, subtopic, err := parseTopic(topic)
if err != nil {
return err
}
ev := subscribeEvent{
operation: clientSubscribe,
clientID: s.Username,
channelID: channelID,
subscriberID: s.ID,
subtopic: subtopic,
}
if err := es.ep.Publish(ctx, ev); err != nil {
return err
}
}
return nil
}
func (es *eventStore) Unsubscribe(ctx context.Context, topics *[]string) error {
return es.handler.Unsubscribe(ctx, topics)
}
func (es *eventStore) Disconnect(ctx context.Context) error {
if err := es.handler.Disconnect(ctx); err != nil {
return err
}
s, ok := session.FromContext(ctx)
if !ok {
return errFailedSession
}
ev := connectEvent{
clientID: clientID,
operation: clientDisconnect,
subscriberID: subscriberID,
clientID: s.Username,
subscriberID: s.ID,
instance: es.instance,
}
return es.ep.Publish(ctx, ev)
}
// Subscribe issues event on MQTT SUBSCRIBE.
func (es *eventStore) Subscribe(ctx context.Context, clientID, channelID, subscriberID, subtopic string) error {
ev := subscribeEvent{
operation: clientSubscribe,
clientID: clientID,
channelID: channelID,
subscriberID: subscriberID,
subtopic: subtopic,
func parseTopic(topic string) (string, string, error) {
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 2 {
return "", "", errMalformedTopic
}
return es.ep.Publish(ctx, ev)
chanID := channelParts[1]
subtopic := channelParts[2]
if subtopic == "" {
return subtopic, chanID, nil
}
subtopic, err := url.QueryUnescape(subtopic)
if err != nil {
return "", "", errMalformedTopic
}
subtopic = strings.ReplaceAll(subtopic, "/", ".")
elems := strings.Split(subtopic, ".")
filteredElems := []string{}
for _, elem := range elems {
if elem == "" {
continue
}
if len(elem) > 1 && (strings.Contains(elem, "*") || strings.Contains(elem, ">")) {
return "", "", errMalformedTopic
}
filteredElems = append(filteredElems, elem)
}
subtopic = strings.Join(filteredElems, ".")
return chanID, subtopic, nil
}
+4 -39
View File
@@ -15,7 +15,6 @@ import (
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/mqtt/events"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -67,13 +66,11 @@ type handler struct {
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
logger *slog.Logger
es events.EventStore
}
// NewHandler creates new Handler entity.
func NewHandler(publisher messaging.Publisher, es events.EventStore, logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) session.Handler {
func NewHandler(publisher messaging.Publisher, logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) session.Handler {
return &handler{
es: es,
logger: logger,
publisher: publisher,
clients: clients,
@@ -107,10 +104,6 @@ func (h *handler) AuthConnect(ctx context.Context) error {
return errInvalidUserId
}
if err := h.es.Connect(ctx, s.Username, s.ID); err != nil {
h.logger.Error(errors.Wrap(ErrFailedPublishConnectEvent, err).Error())
}
return nil
}
@@ -203,18 +196,8 @@ func (h *handler) Subscribe(ctx context.Context, topics *[]string) error {
if !ok {
return errors.Wrap(ErrFailedSubscribe, ErrClientNotInitialized)
}
for _, topic := range *topics {
channelID, subTopic, err := parseTopic(topic)
if err != nil {
return err
}
if err := h.es.Subscribe(ctx, s.Username, channelID, s.ID, subTopic); err != nil {
return errors.Wrap(ErrFailedSubscribeEvent, err)
}
}
h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ",")))
return nil
}
@@ -225,6 +208,7 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
return errors.Wrap(ErrFailedUnsubscribe, ErrClientNotInitialized)
}
h.logger.Info(fmt.Sprintf(LogInfoUnsubscribed, s.ID, strings.Join(*topics, ",")))
return nil
}
@@ -235,9 +219,7 @@ func (h *handler) Disconnect(ctx context.Context) error {
return errors.Wrap(ErrFailedDisconnect, ErrClientNotInitialized)
}
h.logger.Error(fmt.Sprintf(LogInfoDisconnected, s.ID, s.Password))
if err := h.es.Disconnect(ctx, s.Username, s.ID); err != nil {
return errors.Wrap(ErrFailedPublishDisconnectEvent, err)
}
return nil
}
@@ -272,23 +254,6 @@ func (h *handler) authAccess(ctx context.Context, clientID, topic string, msgTyp
return nil
}
func parseTopic(topic string) (string, string, error) {
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 2 {
return "", "", errors.Wrap(ErrFailedPublish, ErrMalformedTopic)
}
chanID := channelParts[1]
subtopic := channelParts[2]
subtopic, err := parseSubtopic(subtopic)
if err != nil {
return "", "", errors.Wrap(ErrFailedParseSubtopic, err)
}
return chanID, subtopic, nil
}
func parseSubtopic(subtopic string) (string, error) {
if subtopic == "" {
return subtopic, nil
+3 -11
View File
@@ -68,9 +68,8 @@ var (
)
var (
clients = new(climocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
eventStore = new(mocks.EventStore)
clients = new(climocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
)
func TestAuthConnect(t *testing.T) {
@@ -147,10 +146,8 @@ func TestAuthConnect(t *testing.T) {
password = string(tc.session.Password)
}
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr)
svcCall := eventStore.On("Connect", mock.Anything, clientID, mock.Anything).Return(tc.err)
err := handler.AuthConnect(ctx)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
clientsCall.Unset()
})
}
@@ -445,11 +442,9 @@ func TestSubscribe(t *testing.T) {
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
eventsCall := eventStore.On("Subscribe", mock.Anything, clientID, chanID, clientID, mock.Anything).Return(nil)
err := handler.Subscribe(ctx, &tc.topic)
assert.Contains(t, logBuffer.String(), tc.logMsg)
assert.Equal(t, tc.err, err)
eventsCall.Unset()
}
}
@@ -519,11 +514,9 @@ func TestDisconnect(t *testing.T) {
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
svcCall := eventStore.On("Disconnect", mock.Anything, clientID, mock.Anything).Return(tc.err)
err := handler.Disconnect(ctx)
assert.Contains(t, logBuffer.String(), tc.logMsg)
assert.Equal(t, tc.err, err)
svcCall.Unset()
}
}
@@ -534,6 +527,5 @@ func newHandler() session.Handler {
}
clients = new(climocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
eventStore = new(mocks.EventStore)
return mqtt.NewHandler(mocks.NewPublisher(), eventStore, logger, clients, channels)
return mqtt.NewHandler(mocks.NewPublisher(), logger, clients, channels)
}
-84
View File
@@ -1,84 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// EventStore is an autogenerated mock type for the EventStore type
type EventStore struct {
mock.Mock
}
// Connect provides a mock function with given fields: ctx, clientID, subscriberID
func (_m *EventStore) Connect(ctx context.Context, clientID string, subscriberID string) error {
ret := _m.Called(ctx, clientID, subscriberID)
if len(ret) == 0 {
panic("no return value specified for Connect")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, clientID, subscriberID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Disconnect provides a mock function with given fields: ctx, clientID, subscriberID
func (_m *EventStore) Disconnect(ctx context.Context, clientID string, subscriberID string) error {
ret := _m.Called(ctx, clientID, subscriberID)
if len(ret) == 0 {
panic("no return value specified for Disconnect")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, clientID, subscriberID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Subscribe provides a mock function with given fields: ctx, clientID, channelID, subscriberID, subtopic
func (_m *EventStore) Subscribe(ctx context.Context, clientID string, channelID string, subscriberID string, subtopic string) error {
ret := _m.Called(ctx, clientID, channelID, subscriberID, subtopic)
if len(ret) == 0 {
panic("no return value specified for Subscribe")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok {
r0 = rf(ctx, clientID, channelID, subscriberID, subtopic)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewEventStore creates a new instance of EventStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewEventStore(t interface {
mock.TestingT
Cleanup(func())
}) *EventStore {
mock := &EventStore{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+4 -2
View File
@@ -35,13 +35,15 @@ func (pe publishEvent) Encode() (map[string]interface{}, error) {
type subscribeEvent struct {
operation string
subscriberID string
subtopic string
clientID string
topic string
}
func (se subscribeEvent) Encode() (map[string]interface{}, error) {
return map[string]interface{}{
"operation": se.operation,
"subscriber_id": se.subscriberID,
"subtopic": se.subtopic,
"client_id": se.clientID,
"topic": se.topic,
}, nil
}
+3 -2
View File
@@ -54,7 +54,8 @@ func (es *pubsubES) Subscribe(ctx context.Context, cfg messaging.SubscriberConfi
se := subscribeEvent{
operation: clientSubscribe,
subscriberID: cfg.ID,
subtopic: cfg.Topic,
clientID: cfg.ClientID,
topic: cfg.Topic,
}
return es.ep.Publish(ctx, se)
@@ -68,7 +69,7 @@ func (es *pubsubES) Unsubscribe(ctx context.Context, id string, topic string) er
se := subscribeEvent{
operation: clientUnsubscribe,
subscriberID: id,
subtopic: topic,
topic: topic,
}
return es.ep.Publish(ctx, se)
+1
View File
@@ -36,6 +36,7 @@ type MessageHandler interface {
type SubscriberConfig struct {
ID string
ClientID string
Topic string
Handler MessageHandler
DeliveryPolicy DeliveryPolicy
+4 -3
View File
@@ -75,9 +75,10 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, chanID, sub
}
subCfg := messaging.SubscriberConfig{
ID: clientID,
Topic: subject,
Handler: c,
ID: clientID,
ClientID: clientID,
Topic: subject,
Handler: c,
}
if err := svc.pubsub.Subscribe(ctx, subCfg); err != nil {
return ErrFailedSubscription
+4 -3
View File
@@ -158,9 +158,10 @@ func TestSubscribe(t *testing.T) {
for _, tc := range cases {
subConfig := messaging.SubscriberConfig{
ID: clientID,
Topic: "channels." + tc.chanID + "." + subTopic,
Handler: c,
ID: clientID,
Topic: "channels." + tc.chanID + "." + subTopic,
ClientID: clientID,
Handler: c,
}
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: tc.clientKey}).Return(tc.authNRes, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{