Sync with SMQ

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
Dusan Borovcanin
2025-04-11 15:31:55 +02:00
parent ffad1181cf
commit faaf42941d
8 changed files with 78 additions and 69 deletions
+10 -12
View File
@@ -144,14 +144,13 @@ func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, to
if err != nil {
return Config{}, errors.Wrap(errCheckChannels, err)
}
cfg.Channels, err = bs.connectionChannels(toConnect, bs.toIDList(existing), session.DomainID, token)
cfg.Channels, err = bs.connectionChannels(ctx, toConnect, bs.toIDList(existing), session.DomainID, token)
if err != nil {
return Config{}, errors.Wrap(errConnectionChannels, err)
}
id := cfg.ClientID
mgClient, err := bs.client(session.DomainID, id, token)
mgClient, err := bs.client(ctx, session.DomainID, id, token)
if err != nil {
return Config{}, errors.Wrap(errClientNotFound, err)
}
@@ -223,7 +222,7 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, session smqaut
return errors.Wrap(errUpdateConnections, err)
}
channels, err := bs.connectionChannels(connections, bs.toIDList(existing), session.DomainID, token)
channels, err := bs.connectionChannels(ctx, connections, bs.toIDList(existing), session.DomainID, token)
if err != nil {
return errors.Wrap(errUpdateConnections, err)
}
@@ -336,7 +335,7 @@ func (bs bootstrapService) ChangeState(ctx context.Context, session smqauthn.Ses
switch state {
case Active:
for _, c := range cfg.Channels {
if err := bs.sdk.ConnectClients(context.Background(), c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil {
if err := bs.sdk.ConnectClients(ctx, c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil {
// Ignore conflict errors as they indicate the connection already exists.
if errors.Contains(err, svcerr.ErrConflict) {
continue
@@ -346,7 +345,7 @@ func (bs bootstrapService) ChangeState(ctx context.Context, session smqauthn.Ses
}
case Inactive:
for _, c := range cfg.Channels {
if err := bs.sdk.DisconnectClients(context.Background(), c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil {
if err := bs.sdk.DisconnectClients(ctx, c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil {
if errors.Contains(err, repoerr.ErrNotFound) {
continue
}
@@ -396,29 +395,28 @@ func (bs bootstrapService) DisconnectClientHandler(ctx context.Context, channelI
}
// Method client retrieves SuperMQ Client creating one if an empty ID is passed.
func (bs bootstrapService) client(domainID, id, token string) (mgsdk.Client, error) {
func (bs bootstrapService) client(ctx context.Context, domainID, id, token string) (mgsdk.Client, error) {
// If Client ID is not provided, then create new client.
if id == "" {
id, err := bs.idProvider.ID()
if err != nil {
return mgsdk.Client{}, errors.Wrap(errCreateClient, err)
}
client, sdkErr := bs.sdk.CreateClient(context.Background(), mgsdk.Client{ID: id, Name: "Bootstrapped Client " + id}, domainID, token)
client, sdkErr := bs.sdk.CreateClient(ctx, mgsdk.Client{ID: id, Name: "Bootstrapped Client " + id}, domainID, token)
if sdkErr != nil {
return mgsdk.Client{}, errors.Wrap(errCreateClient, sdkErr)
}
return client, nil
}
// If Client ID is provided, then retrieve client
client, sdkErr := bs.sdk.Client(context.Background(), id, domainID, token)
client, sdkErr := bs.sdk.Client(ctx, id, domainID, token)
if sdkErr != nil {
return mgsdk.Client{}, errors.Wrap(ErrClients, sdkErr)
}
return client, nil
}
func (bs bootstrapService) connectionChannels(channels, existing []string, domainID, token string) ([]Channel, error) {
func (bs bootstrapService) connectionChannels(ctx context.Context, channels, existing []string, domainID, token string) ([]Channel, error) {
add := make(map[string]bool, len(channels))
for _, ch := range channels {
add[ch] = true
@@ -432,7 +430,7 @@ func (bs bootstrapService) connectionChannels(channels, existing []string, domai
var ret []Channel
for id := range add {
ch, err := bs.sdk.Channel(context.Background(), id, domainID, token)
ch, err := bs.sdk.Channel(ctx, id, domainID, token)
if err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}