MG-1955 - Update Bootstrap service access control (#2199)

Signed-off-by: JeffMboya <jangina.mboya@gmail.com>
This commit is contained in:
JMboya
2024-07-09 14:03:10 +03:00
committed by GitHub
parent c0b9017679
commit b49a2cd012
15 changed files with 2346 additions and 1151 deletions
+10
View File
@@ -100,6 +100,8 @@ paths:
description: Missing or invalid config.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Config does not exist.
"422":
@@ -126,6 +128,8 @@ paths:
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Config does not exist.
"415":
@@ -151,6 +155,8 @@ paths:
description: Failed due to malformed config ID.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"422":
description: Database can't process request.
"500":
@@ -176,6 +182,8 @@ paths:
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Config does not exist.
"415":
@@ -204,6 +212,8 @@ paths:
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Config does not exist.
"415":
+66 -100
View File
@@ -328,25 +328,20 @@ func TestView(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
var channels []channel
for _, ch := range saved.Channels {
for _, ch := range c.Channels {
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
}
data := config{
ThingID: saved.ThingID,
ThingKey: saved.ThingKey,
State: saved.State,
ThingID: c.ThingID,
ThingKey: c.ThingKey,
State: c.State,
Channels: channels,
ExternalID: saved.ExternalID,
ExternalKey: saved.ExternalKey,
Name: saved.Name,
Content: saved.Content,
ExternalID: c.ExternalID,
ExternalKey: c.ExternalKey,
Name: c.Name,
Content: c.Content,
}
cases := []struct {
@@ -360,7 +355,7 @@ func TestView(t *testing.T) {
{
desc: "view a config with invalid token",
auth: invalidToken,
id: saved.ThingID,
id: c.ThingID,
status: http.StatusUnauthorized,
res: config{},
err: svcerr.ErrAuthentication,
@@ -368,7 +363,7 @@ func TestView(t *testing.T) {
{
desc: "view a config",
auth: validToken,
id: saved.ThingID,
id: c.ThingID,
status: http.StatusOK,
res: data,
err: nil,
@@ -384,15 +379,23 @@ func TestView(t *testing.T) {
{
desc: "view a config with an empty token",
auth: "",
id: saved.ThingID,
id: c.ThingID,
status: http.StatusUnauthorized,
res: config{},
err: svcerr.ErrAuthentication,
},
{
desc: "view config without authorization",
auth: validToken,
id: c.ThingID,
status: http.StatusForbidden,
res: config{},
err: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
svcCall := svc.On("View", mock.Anything, mock.Anything, mock.Anything).Return(c, tc.err)
svcCall := svc.On("View", mock.Anything, tc.auth, tc.id).Return(c, tc.err)
req := testRequest{
client: bs.Client(),
method: http.MethodGet,
@@ -422,11 +425,6 @@ func TestUpdate(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
data := toJSON(updateReq)
cases := []struct {
@@ -441,7 +439,7 @@ func TestUpdate(t *testing.T) {
{
desc: "update with invalid token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: invalidToken,
contentType: contentType,
status: http.StatusUnauthorized,
@@ -450,7 +448,7 @@ func TestUpdate(t *testing.T) {
{
desc: "update with an empty token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: "",
contentType: contentType,
status: http.StatusUnauthorized,
@@ -459,7 +457,7 @@ func TestUpdate(t *testing.T) {
{
desc: "update a valid config",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusOK,
@@ -468,7 +466,7 @@ func TestUpdate(t *testing.T) {
{
desc: "update a config with wrong content type",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: "",
status: http.StatusUnsupportedMediaType,
@@ -486,7 +484,7 @@ func TestUpdate(t *testing.T) {
{
desc: "update a config with invalid request format",
req: "}",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusBadRequest,
@@ -494,7 +492,7 @@ func TestUpdate(t *testing.T) {
},
{
desc: "update a config with an empty request",
id: saved.ThingID,
id: c.ThingID,
req: "",
auth: validToken,
contentType: contentType,
@@ -504,7 +502,7 @@ func TestUpdate(t *testing.T) {
}
for _, tc := range cases {
svcCall := svcCall.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
svcCall := svc.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
req := testRequest{
client: bs.Client(),
method: http.MethodPut,
@@ -525,11 +523,6 @@ func TestUpdateCert(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
data := toJSON(updateReq)
cases := []struct {
@@ -544,7 +537,7 @@ func TestUpdateCert(t *testing.T) {
{
desc: "update with invalid token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: invalidToken,
contentType: contentType,
status: http.StatusUnauthorized,
@@ -553,7 +546,7 @@ func TestUpdateCert(t *testing.T) {
{
desc: "update with an empty token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: "",
contentType: contentType,
status: http.StatusUnauthorized,
@@ -562,7 +555,7 @@ func TestUpdateCert(t *testing.T) {
{
desc: "update a valid config",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusOK,
@@ -571,7 +564,7 @@ func TestUpdateCert(t *testing.T) {
{
desc: "update a config with wrong content type",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: "",
status: http.StatusUnsupportedMediaType,
@@ -589,7 +582,7 @@ func TestUpdateCert(t *testing.T) {
{
desc: "update a config with invalid request format",
req: "}",
id: saved.ThingKey,
id: c.ThingKey,
auth: validToken,
contentType: contentType,
status: http.StatusBadRequest,
@@ -597,7 +590,7 @@ func TestUpdateCert(t *testing.T) {
},
{
desc: "update a config with an empty request",
id: saved.ThingID,
id: c.ThingID,
req: "",
auth: validToken,
contentType: contentType,
@@ -627,12 +620,6 @@ func TestUpdateConnections(t *testing.T) {
bs, svc := newBootstrapServer()
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
data := toJSON(updateReq)
invalidChannels := updateReq
@@ -652,7 +639,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update connections with invalid token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: invalidToken,
contentType: contentType,
status: http.StatusUnauthorized,
@@ -661,7 +648,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update connections with an empty token",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: "",
contentType: contentType,
status: http.StatusUnauthorized,
@@ -670,7 +657,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update connections valid config",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusOK,
@@ -679,7 +666,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update connections with wrong content type",
req: data,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: "",
status: http.StatusUnsupportedMediaType,
@@ -697,7 +684,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update connections with invalid channels",
req: wrongData,
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusNotFound,
@@ -706,7 +693,7 @@ func TestUpdateConnections(t *testing.T) {
{
desc: "update a config with invalid request format",
req: "}",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
contentType: contentType,
status: http.StatusBadRequest,
@@ -714,7 +701,7 @@ func TestUpdateConnections(t *testing.T) {
},
{
desc: "update a config with an empty request",
id: saved.ThingID,
id: c.ThingID,
req: "",
auth: validToken,
contentType: contentType,
@@ -724,7 +711,7 @@ func TestUpdateConnections(t *testing.T) {
}
for _, tc := range cases {
repoCall := svcCall.On("UpdateConnections", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
repoCall := svc.On("UpdateConnections", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
req := testRequest{
client: bs.Client(),
method: http.MethodPut,
@@ -758,28 +745,22 @@ func TestList(t *testing.T) {
c.Name = fmt.Sprintf("%s-%d", addName, i)
c.ExternalKey = fmt.Sprintf("%s%s", addExternalKey, strconv.Itoa(i))
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
var channels []channel
for _, ch := range saved.Channels {
for _, ch := range c.Channels {
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
}
s := config{
ThingID: saved.ThingID,
ThingKey: saved.ThingKey,
ThingID: c.ThingID,
ThingKey: c.ThingKey,
Channels: channels,
ExternalID: saved.ExternalID,
ExternalKey: saved.ExternalKey,
Name: saved.Name,
Content: saved.Content,
State: saved.State,
ExternalID: c.ExternalID,
ExternalKey: c.ExternalKey,
Name: c.Name,
Content: c.Content,
State: c.State,
}
list[i] = s
}
// Change state of first 20 elements for filtering tests.
for i := 0; i < changedStateNum; i++ {
state := bootstrap.Active
@@ -1020,11 +1001,6 @@ func TestRemove(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
cases := []struct {
desc string
id string
@@ -1034,14 +1010,14 @@ func TestRemove(t *testing.T) {
}{
{
desc: "remove with invalid token",
id: saved.ThingID,
id: c.ThingID,
auth: invalidToken,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "remove with an empty token",
id: saved.ThingID,
id: c.ThingID,
auth: "",
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
@@ -1055,7 +1031,7 @@ func TestRemove(t *testing.T) {
},
{
desc: "remove config",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
status: http.StatusNoContent,
err: nil,
@@ -1089,16 +1065,11 @@ func TestBootstrap(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
encExternKey, err := enc([]byte(c.ExternalKey))
assert.Nil(t, err, fmt.Sprintf("Encrypting config expected to succeed: %s.\n", err))
var channels []channel
for _, ch := range saved.Channels {
for _, ch := range c.Channels {
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
}
@@ -1111,13 +1082,13 @@ func TestBootstrap(t *testing.T) {
ClientKey string `json:"client_key"`
CACert string `json:"ca_cert"`
}{
ThingID: saved.ThingID,
ThingKey: saved.ThingKey,
ThingID: c.ThingID,
ThingKey: c.ThingKey,
Channels: channels,
Content: saved.Content,
ClientCert: saved.ClientCert,
ClientKey: saved.ClientKey,
CACert: saved.CACert,
Content: c.Content,
ClientCert: c.ClientCert,
ClientKey: c.ClientKey,
CACert: c.CACert,
}
data := toJSON(s)
@@ -1225,11 +1196,6 @@ func TestChangeState(t *testing.T) {
defer bs.Close()
c := newConfig()
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
saved, err := svc.Add(context.Background(), validToken, c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
svcCall.Unset()
inactive := fmt.Sprintf("{\"state\": %d}", bootstrap.Inactive)
active := fmt.Sprintf("{\"state\": %d}", bootstrap.Active)
@@ -1244,7 +1210,7 @@ func TestChangeState(t *testing.T) {
}{
{
desc: "change state with invalid token",
id: saved.ThingID,
id: c.ThingID,
auth: invalidToken,
state: active,
contentType: contentType,
@@ -1253,7 +1219,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state with an empty token",
id: saved.ThingID,
id: c.ThingID,
auth: "",
state: active,
contentType: contentType,
@@ -1262,7 +1228,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state with invalid content type",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
state: active,
contentType: "",
@@ -1271,7 +1237,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state to active",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
state: active,
contentType: contentType,
@@ -1280,7 +1246,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state to inactive",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
state: inactive,
contentType: contentType,
@@ -1298,7 +1264,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state to invalid value",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
state: fmt.Sprintf("{\"state\": %d}", -3),
contentType: contentType,
@@ -1307,7 +1273,7 @@ func TestChangeState(t *testing.T) {
},
{
desc: "change state with invalid data",
id: saved.ThingID,
id: c.ThingID,
auth: validToken,
state: "",
contentType: contentType,
+10
View File
@@ -36,6 +36,16 @@ func (req addReq) validate() error {
return apiutil.ErrBearerKey
}
if len(req.Channels) == 0 {
return apiutil.ErrEmptyList
}
for _, channel := range req.Channels {
if channel == "" {
return apiutil.ErrMissingID
}
}
return nil
}
+50 -1
View File
@@ -8,23 +8,39 @@ import (
"testing"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/apiutil"
"github.com/stretchr/testify/assert"
)
var (
channel1 = testsutil.GenerateUUID(&testing.T{})
channel2 = testsutil.GenerateUUID(&testing.T{})
)
func TestAddReqValidation(t *testing.T) {
cases := []struct {
desc string
token string
externalID string
externalKey string
channels []string
err error
}{
{
desc: "empty key",
desc: "valid request",
token: "token",
externalID: "external-id",
externalKey: "external-key",
channels: []string{channel1, channel2},
err: nil,
},
{
desc: "empty token",
token: "",
externalID: "external-id",
externalKey: "external-key",
channels: []string{channel1, channel2},
err: apiutil.ErrBearerToken,
},
{
@@ -32,6 +48,7 @@ func TestAddReqValidation(t *testing.T) {
token: "token",
externalID: "",
externalKey: "external-key",
channels: []string{channel1, channel2},
err: apiutil.ErrMissingID,
},
{
@@ -39,8 +56,33 @@ func TestAddReqValidation(t *testing.T) {
token: "token",
externalID: "external-id",
externalKey: "",
channels: []string{channel1, channel2},
err: apiutil.ErrBearerKey,
},
{
desc: "empty external key and external ID",
token: "token",
externalID: "",
externalKey: "",
channels: []string{channel1, channel2},
err: apiutil.ErrMissingID,
},
{
desc: "empty channels",
token: "token",
externalID: "external-id",
externalKey: "external-key",
channels: []string{},
err: apiutil.ErrEmptyList,
},
{
desc: "empty channel value",
token: "token",
externalID: "external-id",
externalKey: "external-key",
channels: []string{channel1, ""},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
@@ -48,6 +90,7 @@ func TestAddReqValidation(t *testing.T) {
token: tc.token,
ExternalID: tc.externalID,
ExternalKey: tc.externalKey,
Channels: tc.channels,
}
err := req.validate()
@@ -93,6 +136,12 @@ func TestUpdateReqValidation(t *testing.T) {
id string
err error
}{
{
desc: "valid request",
token: "token",
id: "id",
err: nil,
},
{
desc: "empty token",
token: "",
+10 -10
View File
@@ -17,7 +17,7 @@ import (
// MGChannels is a list of Magistrala Channels corresponding Magistrala Thing connects to.
type Config struct {
ThingID string `json:"thing_id"`
Owner string `json:"owner,omitempty"`
DomainID string `json:"domain_id,omitempty"`
Name string `json:"name,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
@@ -35,7 +35,7 @@ type Channel struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Owner string `json:"owner_id"`
DomainID string `json:"domain_id"`
Parent string `json:"parent_id,omitempty"`
Description string `json:"description,omitempty"`
CreatedAt time.Time `json:"created_at"`
@@ -69,11 +69,11 @@ type ConfigRepository interface {
// RetrieveByID retrieves the Config having the provided identifier, that is owned
// by the specified user.
RetrieveByID(ctx context.Context, owner, id string) (Config, error)
RetrieveByID(ctx context.Context, domainID, id string) (Config, error)
// RetrieveAll retrieves a subset of Configs that are owned
// by the specific user, with given filter parameters.
RetrieveAll(ctx context.Context, owner string, filter Filter, offset, limit uint64) ConfigsPage
RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter Filter, offset, limit uint64) ConfigsPage
// RetrieveByExternalID returns Config for given external ID.
RetrieveByExternalID(ctx context.Context, externalID string) (Config, error)
@@ -82,23 +82,23 @@ type ConfigRepository interface {
// to indicate operation failure.
Update(ctx context.Context, cfg Config) error
// UpdateCerts updates and returns an existing Config certificate and owner.
// UpdateCerts updates and returns an existing Config certificate and domainID.
// A non-nil error is returned to indicate operation failure.
UpdateCert(ctx context.Context, owner, thingID, clientCert, clientKey, caCert string) (Config, error)
UpdateCert(ctx context.Context, domainID, thingID, clientCert, clientKey, caCert string) (Config, error)
// UpdateConnections updates a list of Channels the Config is connected to
// adding new Channels if needed.
UpdateConnections(ctx context.Context, owner, id string, channels []Channel, connections []string) error
UpdateConnections(ctx context.Context, domainID, id string, channels []Channel, connections []string) error
// Remove removes the Config having the provided identifier, that is owned
// by the specified user.
Remove(ctx context.Context, owner, id string) error
Remove(ctx context.Context, domainID, id string) error
// ChangeState changes of the Config, that is owned by the specific user.
ChangeState(ctx context.Context, owner, id string, state State) error
ChangeState(ctx context.Context, domainID, id string, state State) error
// ListExisting retrieves those channels from the given list that exist in DB.
ListExisting(ctx context.Context, owner string, ids []string) ([]Channel, error)
ListExisting(ctx context.Context, domainID string, ids []string) ([]Channel, error)
// Methods RemoveThing, UpdateChannel, and RemoveChannel are related to
// event sourcing. That's why these methods surpass ownership check.
+4 -4
View File
@@ -58,8 +58,8 @@ func (ce configEvent) Encode() (map[string]interface{}, error) {
if ce.Content != "" {
val["content"] = ce.Content
}
if ce.Owner != "" {
val["owner"] = ce.Owner
if ce.DomainID != "" {
val["domain_id "] = ce.DomainID
}
if ce.Name != "" {
val["name"] = ce.Name
@@ -143,8 +143,8 @@ func (be bootstrapEvent) Encode() (map[string]interface{}, error) {
if be.Content != "" {
val["content"] = be.Content
}
if be.Owner != "" {
val["owner"] = be.Owner
if be.DomainID != "" {
val["domain_id "] = be.DomainID
}
if be.Name != "" {
val["name"] = be.Name
File diff suppressed because it is too large Load Diff
+35 -35
View File
@@ -17,9 +17,9 @@ type ConfigRepository struct {
mock.Mock
}
// ChangeState provides a mock function with given fields: ctx, owner, id, state
func (_m *ConfigRepository) ChangeState(ctx context.Context, owner string, id string, state bootstrap.State) error {
ret := _m.Called(ctx, owner, id, state)
// ChangeState provides a mock function with given fields: ctx, domainID, id, state
func (_m *ConfigRepository) ChangeState(ctx context.Context, domainID string, id string, state bootstrap.State) error {
ret := _m.Called(ctx, domainID, id, state)
if len(ret) == 0 {
panic("no return value specified for ChangeState")
@@ -27,7 +27,7 @@ func (_m *ConfigRepository) ChangeState(ctx context.Context, owner string, id st
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, bootstrap.State) error); ok {
r0 = rf(ctx, owner, id, state)
r0 = rf(ctx, domainID, id, state)
} else {
r0 = ret.Error(0)
}
@@ -71,9 +71,9 @@ func (_m *ConfigRepository) DisconnectThing(ctx context.Context, channelID strin
return r0
}
// ListExisting provides a mock function with given fields: ctx, owner, ids
func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids []string) ([]bootstrap.Channel, error) {
ret := _m.Called(ctx, owner, ids)
// ListExisting provides a mock function with given fields: ctx, domainID, ids
func (_m *ConfigRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) {
ret := _m.Called(ctx, domainID, ids)
if len(ret) == 0 {
panic("no return value specified for ListExisting")
@@ -82,10 +82,10 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
var r0 []bootstrap.Channel
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, []string) ([]bootstrap.Channel, error)); ok {
return rf(ctx, owner, ids)
return rf(ctx, domainID, ids)
}
if rf, ok := ret.Get(0).(func(context.Context, string, []string) []bootstrap.Channel); ok {
r0 = rf(ctx, owner, ids)
r0 = rf(ctx, domainID, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]bootstrap.Channel)
@@ -93,7 +93,7 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
}
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
r1 = rf(ctx, owner, ids)
r1 = rf(ctx, domainID, ids)
} else {
r1 = ret.Error(1)
}
@@ -101,9 +101,9 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
return r0, r1
}
// Remove provides a mock function with given fields: ctx, owner, id
func (_m *ConfigRepository) Remove(ctx context.Context, owner string, id string) error {
ret := _m.Called(ctx, owner, id)
// Remove provides a mock function with given fields: ctx, domainID, id
func (_m *ConfigRepository) Remove(ctx context.Context, domainID string, id string) error {
ret := _m.Called(ctx, domainID, id)
if len(ret) == 0 {
panic("no return value specified for Remove")
@@ -111,7 +111,7 @@ func (_m *ConfigRepository) Remove(ctx context.Context, owner string, id string)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, owner, id)
r0 = rf(ctx, domainID, id)
} else {
r0 = ret.Error(0)
}
@@ -155,17 +155,17 @@ func (_m *ConfigRepository) RemoveThing(ctx context.Context, id string) error {
return r0
}
// RetrieveAll provides a mock function with given fields: ctx, owner, filter, offset, limit
func (_m *ConfigRepository) RetrieveAll(ctx context.Context, owner string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
ret := _m.Called(ctx, owner, filter, offset, limit)
// RetrieveAll provides a mock function with given fields: ctx, domainID, thingIDs, filter, offset, limit
func (_m *ConfigRepository) RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
ret := _m.Called(ctx, domainID, thingIDs, filter, offset, limit)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 bootstrap.ConfigsPage
if rf, ok := ret.Get(0).(func(context.Context, string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
r0 = rf(ctx, owner, filter, offset, limit)
if rf, ok := ret.Get(0).(func(context.Context, string, []string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
r0 = rf(ctx, domainID, thingIDs, filter, offset, limit)
} else {
r0 = ret.Get(0).(bootstrap.ConfigsPage)
}
@@ -201,9 +201,9 @@ func (_m *ConfigRepository) RetrieveByExternalID(ctx context.Context, externalID
return r0, r1
}
// RetrieveByID provides a mock function with given fields: ctx, owner, id
func (_m *ConfigRepository) RetrieveByID(ctx context.Context, owner string, id string) (bootstrap.Config, error) {
ret := _m.Called(ctx, owner, id)
// RetrieveByID provides a mock function with given fields: ctx, domainID, id
func (_m *ConfigRepository) RetrieveByID(ctx context.Context, domainID string, id string) (bootstrap.Config, error) {
ret := _m.Called(ctx, domainID, id)
if len(ret) == 0 {
panic("no return value specified for RetrieveByID")
@@ -212,16 +212,16 @@ func (_m *ConfigRepository) RetrieveByID(ctx context.Context, owner string, id s
var r0 bootstrap.Config
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (bootstrap.Config, error)); ok {
return rf(ctx, owner, id)
return rf(ctx, domainID, id)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) bootstrap.Config); ok {
r0 = rf(ctx, owner, id)
r0 = rf(ctx, domainID, id)
} else {
r0 = ret.Get(0).(bootstrap.Config)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, owner, id)
r1 = rf(ctx, domainID, id)
} else {
r1 = ret.Error(1)
}
@@ -275,9 +275,9 @@ func (_m *ConfigRepository) Update(ctx context.Context, cfg bootstrap.Config) er
return r0
}
// UpdateCert provides a mock function with given fields: ctx, owner, thingID, clientCert, clientKey, caCert
func (_m *ConfigRepository) UpdateCert(ctx context.Context, owner string, thingID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
ret := _m.Called(ctx, owner, thingID, clientCert, clientKey, caCert)
// UpdateCert provides a mock function with given fields: ctx, domainID, thingID, clientCert, clientKey, caCert
func (_m *ConfigRepository) UpdateCert(ctx context.Context, domainID string, thingID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
ret := _m.Called(ctx, domainID, thingID, clientCert, clientKey, caCert)
if len(ret) == 0 {
panic("no return value specified for UpdateCert")
@@ -286,16 +286,16 @@ func (_m *ConfigRepository) UpdateCert(ctx context.Context, owner string, thingI
var r0 bootstrap.Config
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (bootstrap.Config, error)); ok {
return rf(ctx, owner, thingID, clientCert, clientKey, caCert)
return rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) bootstrap.Config); ok {
r0 = rf(ctx, owner, thingID, clientCert, clientKey, caCert)
r0 = rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
} else {
r0 = ret.Get(0).(bootstrap.Config)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok {
r1 = rf(ctx, owner, thingID, clientCert, clientKey, caCert)
r1 = rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
} else {
r1 = ret.Error(1)
}
@@ -321,9 +321,9 @@ func (_m *ConfigRepository) UpdateChannel(ctx context.Context, c bootstrap.Chann
return r0
}
// UpdateConnections provides a mock function with given fields: ctx, owner, id, channels, connections
func (_m *ConfigRepository) UpdateConnections(ctx context.Context, owner string, id string, channels []bootstrap.Channel, connections []string) error {
ret := _m.Called(ctx, owner, id, channels, connections)
// UpdateConnections provides a mock function with given fields: ctx, domainID, id, channels, connections
func (_m *ConfigRepository) UpdateConnections(ctx context.Context, domainID string, id string, channels []bootstrap.Channel, connections []string) error {
ret := _m.Called(ctx, domainID, id, channels, connections)
if len(ret) == 0 {
panic("no return value specified for UpdateConnections")
@@ -331,7 +331,7 @@ func (_m *ConfigRepository) UpdateConnections(ctx context.Context, owner string,
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, []bootstrap.Channel, []string) error); ok {
r0 = rf(ctx, owner, id, channels, connections)
r0 = rf(ctx, domainID, id, channels, connections)
} else {
r0 = ret.Error(0)
}
+107 -102
View File
@@ -49,8 +49,8 @@ func NewConfigRepository(db postgres.Database, log *slog.Logger) bootstrap.Confi
}
func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string) (thingID string, err error) {
q := `INSERT INTO configs (magistrala_thing, owner, name, client_cert, client_key, ca_cert, magistrala_key, external_id, external_key, content, state)
VALUES (:magistrala_thing, :owner, :name, :client_cert, :client_key, :ca_cert, :magistrala_key, :external_id, :external_key, :content, :state)`
q := `INSERT INTO configs (magistrala_thing, domain_id, name, client_cert, client_key, ca_cert, magistrala_key, external_id, external_key, content, state)
VALUES (:magistrala_thing, :domain_id, :name, :client_cert, :client_key, :ca_cert, :magistrala_key, :external_id, :external_key, :content, :state)`
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
@@ -60,7 +60,7 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
defer func() {
if err != nil {
err = cr.rollback(err, tx)
err = cr.rollback("Save method", err, tx)
}
}()
@@ -68,13 +68,13 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
switch pgErr := err.(type) {
case *pgconn.PgError:
if pgErr.Code == pgerrcode.UniqueViolation {
return "", repoerr.ErrConflict
err = repoerr.ErrConflict
}
}
return "", err
}
if err := insertChannels(ctx, cfg.Owner, cfg.Channels, tx); err != nil {
if err := insertChannels(cfg.DomainID, cfg.Channels, tx); err != nil {
return "", errors.Wrap(errSaveChannels, err)
}
@@ -82,20 +82,21 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
return "", errors.Wrap(errSaveConnections, err)
}
if err := tx.Commit(); err != nil {
return "", err
if commitErr := tx.Commit(); commitErr != nil {
return "", commitErr
}
return cfg.ThingID, nil
}
func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (bootstrap.Config, error) {
func (cr configRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Config, error) {
q := `SELECT magistrala_thing, magistrala_key, external_id, external_key, name, content, state, client_cert, ca_cert
FROM configs
WHERE magistrala_thing = :magistrala_thing AND owner = :owner`
WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id`
dbcfg := dbConfig{
ThingID: id,
Owner: owner,
ThingID: id,
DomainID: domainID,
}
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
@@ -116,8 +117,8 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
q = `SELECT magistrala_channel, name, metadata FROM channels ch
INNER JOIN connections conn
ON ch.magistrala_channel = conn.channel_id AND ch.owner = conn.config_owner
WHERE conn.config_id = :magistrala_thing AND conn.config_owner = :owner`
ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id
WHERE conn.config_id = :magistrala_thing AND conn.domain_id = :domain_id`
rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
@@ -133,7 +134,7 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
cr.log.Error(fmt.Sprintf("Failed to read connected thing due to %s", err))
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
dbch.Owner = nullString(dbcfg.Owner)
dbch.DomainID = nullString(dbcfg.DomainID)
ch, err := toChannel(dbch)
if err != nil {
@@ -148,12 +149,12 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
return cfg, nil
}
func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
search, params := cr.retrieveAll(owner, filter)
func (cr configRepository) RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
search, params := buildRetrieveQueryParams(domainID, thingIDs, filter)
n := len(params)
q := `SELECT magistrala_thing, magistrala_key, external_id, external_key, name, content, state
FROM configs %s ORDER BY magistrala_thing LIMIT $%d OFFSET $%d`
FROM configs %s ORDER BY magistrala_thing LIMIT $%d OFFSET $%d`
q = fmt.Sprintf(q, search, n+1, n+2)
rows, err := cr.db.QueryContext(ctx, q, append(params, limit, offset)...)
@@ -167,7 +168,7 @@ func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter
configs := []bootstrap.Config{}
for rows.Next() {
c := bootstrap.Config{Owner: owner}
c := bootstrap.Config{DomainID: domainID}
if err := rows.Scan(&c.ThingID, &c.ThingKey, &c.ExternalID, &c.ExternalKey, &name, &content, &c.State); err != nil {
cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err))
return bootstrap.ConfigsPage{}
@@ -195,7 +196,7 @@ func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter
}
func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
q := `SELECT magistrala_thing, magistrala_key, external_key, owner, name, client_cert, client_key, ca_cert, content, state
q := `SELECT magistrala_thing, magistrala_key, external_key, domain_id, name, client_cert, client_key, ca_cert, content, state
FROM configs
WHERE external_id = :external_id`
dbcfg := dbConfig{
@@ -219,9 +220,9 @@ func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID
}
q = `SELECT magistrala_channel, name, metadata FROM channels ch
INNER JOIN connections conn
ON ch.magistrala_channel = conn.channel_id AND ch.owner = conn.config_owner
WHERE conn.config_id = :magistrala_thing AND conn.config_owner = :owner`
INNER JOIN connections conn
ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id
WHERE conn.config_id = :magistrala_thing AND conn.domain_id = :domain_id`
rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
@@ -254,13 +255,13 @@ func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID
}
func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
q := `UPDATE configs SET name = :name, content = :content WHERE magistrala_thing = :magistrala_thing AND owner = :owner `
q := `UPDATE configs SET name = :name, content = :content WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id `
dbcfg := dbConfig{
Name: nullString(cfg.Name),
Content: nullString(cfg.Content),
ThingID: cfg.ThingID,
Owner: cfg.Owner,
Name: nullString(cfg.Name),
Content: nullString(cfg.Content),
ThingID: cfg.ThingID,
DomainID: cfg.DomainID,
}
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
@@ -280,14 +281,14 @@ func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) err
return nil
}
func (cr configRepository) UpdateCert(ctx context.Context, owner, thingID, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND owner = :owner
func (cr configRepository) UpdateCert(ctx context.Context, domainID, thingID, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id
RETURNING magistrala_thing, client_cert, client_key, ca_cert`
dbcfg := dbConfig{
ThingID: thingID,
ClientCert: nullString(clientCert),
Owner: owner,
DomainID: domainID,
ClientKey: nullString(clientKey),
CaCert: nullString(caCert),
}
@@ -309,7 +310,7 @@ func (cr configRepository) UpdateCert(ctx context.Context, owner, thingID, clien
return toConfig(dbcfg), nil
}
func (cr configRepository) UpdateConnections(ctx context.Context, owner, id string, channels []bootstrap.Channel, connections []string) error {
func (cr configRepository) UpdateConnections(ctx context.Context, domainID, id string, channels []bootstrap.Channel, connections []string) (err error) {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
@@ -317,35 +318,37 @@ func (cr configRepository) UpdateConnections(ctx context.Context, owner, id stri
defer func() {
if err != nil {
err = cr.rollback(err, tx)
err = cr.rollback("UpdateConnections method", err, tx)
} else {
if commitErr := tx.Commit(); commitErr != nil {
err = commitErr
}
}
}()
if err := insertChannels(ctx, owner, channels, tx); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
if err = insertChannels(domainID, channels, tx); err != nil {
err = errors.Wrap(repoerr.ErrUpdateEntity, err)
return err
}
if err := updateConnections(ctx, owner, id, connections, tx); err != nil {
if err = updateConnections(domainID, id, connections, tx); err != nil {
if e, ok := err.(*pgconn.PgError); ok {
if e.Code == pgerrcode.ForeignKeyViolation {
return repoerr.ErrNotFound
err = repoerr.ErrNotFound
}
}
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := tx.Commit(); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
err = errors.Wrap(repoerr.ErrUpdateEntity, err)
return err
}
return nil
}
func (cr configRepository) Remove(ctx context.Context, owner, id string) error {
q := `DELETE FROM configs WHERE magistrala_thing = :magistrala_thing AND owner = :owner`
func (cr configRepository) Remove(ctx context.Context, domainID, id string) error {
q := `DELETE FROM configs WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id`
dbcfg := dbConfig{
ThingID: id,
Owner: owner,
ThingID: id,
DomainID: domainID,
}
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
@@ -359,13 +362,13 @@ func (cr configRepository) Remove(ctx context.Context, owner, id string) error {
return nil
}
func (cr configRepository) ChangeState(ctx context.Context, owner, id string, state bootstrap.State) error {
q := `UPDATE configs SET state = :state WHERE magistrala_thing = :magistrala_thing AND owner = :owner;`
func (cr configRepository) ChangeState(ctx context.Context, domainID, id string, state bootstrap.State) error {
q := `UPDATE configs SET state = :state WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id;`
dbcfg := dbConfig{
ThingID: id,
State: state,
Owner: owner,
ThingID: id,
State: state,
DomainID: domainID,
}
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
@@ -385,7 +388,7 @@ func (cr configRepository) ChangeState(ctx context.Context, owner, id string, st
return nil
}
func (cr configRepository) ListExisting(ctx context.Context, owner string, ids []string) ([]bootstrap.Channel, error) {
func (cr configRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) {
var channels []bootstrap.Channel
if len(ids) == 0 {
return channels, nil
@@ -396,8 +399,8 @@ func (cr configRepository) ListExisting(ctx context.Context, owner string, ids [
return []bootstrap.Channel{}, err
}
q := "SELECT magistrala_channel, name, metadata FROM channels WHERE owner = $1 AND magistrala_channel = ANY ($2)"
rows, err := cr.db.QueryxContext(ctx, q, owner, chans)
q := "SELECT magistrala_channel, name, metadata FROM channels WHERE domain_id = $1 AND magistrala_channel = ANY ($2)"
rows, err := cr.db.QueryxContext(ctx, q, domainID, chans)
if err != nil {
return []bootstrap.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
@@ -472,62 +475,66 @@ func (cr configRepository) ConnectThing(ctx context.Context, channelID, thingID
func (cr configRepository) DisconnectThing(ctx context.Context, channelID, thingID string) error {
q := `UPDATE configs SET state = $1 WHERE EXISTS (
SELECT 1 FROM connections WHERE config_id = $2 AND channel_id = $3)`
result, err := cr.db.ExecContext(ctx, q, bootstrap.Inactive, thingID, channelID)
_, err := cr.db.ExecContext(ctx, q, bootstrap.Inactive, thingID, channelID)
if err != nil {
return errors.Wrap(errDisconnectThing, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (cr configRepository) retrieveAll(owner string, filter bootstrap.Filter) (string, []interface{}) {
template := `WHERE owner = $1 %s`
params := []interface{}{owner}
// One empty string so that strings Join works if only one filter is applied.
queries := []string{""}
// Since owner is the first param, start from 2.
counter := 2
func buildRetrieveQueryParams(domainID string, thingIDs []string, filter bootstrap.Filter) (string, []interface{}) {
params := []interface{}{}
queries := []string{}
if len(thingIDs) != 0 {
queries = append(queries, fmt.Sprintf("magistrala_thing IN ('%s')", strings.Join(thingIDs, "','")))
} else if domainID != "" {
params = append(params, domainID)
queries = append(queries, fmt.Sprintf("domain_id = $%d", len(params)))
}
// Adjust the starting point for placeholders based on the current length of params
counter := len(params) + 1
for k, v := range filter.FullMatch {
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
params = append(params, v)
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
counter++
}
for k, v := range filter.PartialMatch {
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
params = append(params, v)
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
counter++
}
f := strings.Join(queries, " AND ")
return fmt.Sprintf(template, f), params
if len(queries) > 0 {
return "WHERE " + strings.Join(queries, " AND "), params
}
return "", params
}
func (cr configRepository) rollback(defErr error, tx *sqlx.Tx) error {
func (cr configRepository) rollback(content string, defErr error, tx *sqlx.Tx) error {
if err := tx.Rollback(); err != nil {
return errors.Wrap(defErr, errors.Wrap(errors.New("failed to rollback"), err))
return errors.Wrap(defErr, errors.Wrap(errors.New("failed to rollback at "+content), err))
}
return defErr
}
func insertChannels(_ context.Context, owner string, channels []bootstrap.Channel, tx *sqlx.Tx) error {
func insertChannels(domainID string, channels []bootstrap.Channel, tx *sqlx.Tx) error {
if len(channels) == 0 {
return nil
}
var chans []dbChannel
for _, ch := range channels {
dbch, err := toDBChannel(owner, ch)
dbch, err := toDBChannel(domainID, ch)
if err != nil {
return err
}
chans = append(chans, dbch)
}
q := `INSERT INTO channels (magistrala_channel, owner, name, metadata, parent_id, description, created_at, updated_at, updated_by, status)
VALUES (:magistrala_channel, :owner, :name, :metadata, :parent_id, :description, :created_at, :updated_at, :updated_by, :status)`
q := `INSERT INTO channels (magistrala_channel, domain_id, name, metadata, parent_id, description, created_at, updated_at, updated_by, status)
VALUES (:magistrala_channel, :domain_id, :name, :metadata, :parent_id, :description, :created_at, :updated_at, :updated_by, :status)`
if _, err := tx.NamedExec(q, chans); err != nil {
e := err
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgerrcode.UniqueViolation {
@@ -544,15 +551,15 @@ func insertConnections(_ context.Context, cfg bootstrap.Config, connections []st
return nil
}
q := `INSERT INTO connections (config_id, channel_id, config_owner, channel_owner)
VALUES (:config_id, :channel_id, :config_owner, :channel_owner)`
q := `INSERT INTO connections (config_id, channel_id, domain_id)
VALUES (:config_id, :channel_id, :domain_id)`
conns := []dbConnection{}
for _, conn := range connections {
dbconn := dbConnection{
Config: cfg.ThingID,
Channel: conn,
ConfigOwner: cfg.Owner,
ChannelOwner: cfg.Owner,
Config: cfg.ThingID,
Channel: conn,
DomainID: cfg.DomainID,
}
conns = append(conns, dbconn)
}
@@ -561,13 +568,13 @@ func insertConnections(_ context.Context, cfg bootstrap.Config, connections []st
return err
}
func updateConnections(_ context.Context, owner, id string, connections []string, tx *sqlx.Tx) error {
func updateConnections(domainID, id string, connections []string, tx *sqlx.Tx) error {
if len(connections) == 0 {
return nil
}
q := `DELETE FROM connections
WHERE config_id = $1 AND config_owner = $2 AND channel_owner = $2
WHERE config_id = $1 AND domain_id = $2
AND channel_id NOT IN ($3)`
var conn pgtype.TextArray
@@ -575,7 +582,7 @@ func updateConnections(_ context.Context, owner, id string, connections []string
return err
}
res, err := tx.Exec(q, id, owner, conn)
res, err := tx.Exec(q, id, domainID, conn)
if err != nil {
return err
}
@@ -585,16 +592,15 @@ func updateConnections(_ context.Context, owner, id string, connections []string
return err
}
q = `INSERT INTO connections (config_id, channel_id, config_owner, channel_owner)
VALUES (:config_id, :channel_id, :config_owner, :channel_owner)`
q = `INSERT INTO connections (config_id, channel_id, domain_id)
VALUES (:config_id, :channel_id, :domain_id)`
conns := []dbConnection{}
for _, conn := range connections {
dbconn := dbConnection{
Config: id,
Channel: conn,
ConfigOwner: owner,
ChannelOwner: owner,
Config: id,
Channel: conn,
DomainID: domainID,
}
conns = append(conns, dbconn)
}
@@ -636,7 +642,7 @@ func nullTime(t time.Time) sql.NullTime {
type dbConfig struct {
ThingID string `db:"magistrala_thing"`
Owner string `db:"owner"`
DomainID string `db:"domain_id"`
Name sql.NullString `db:"name"`
ClientCert sql.NullString `db:"client_cert"`
ClientKey sql.NullString `db:"client_key"`
@@ -651,7 +657,7 @@ type dbConfig struct {
func toDBConfig(cfg bootstrap.Config) dbConfig {
return dbConfig{
ThingID: cfg.ThingID,
Owner: cfg.Owner,
DomainID: cfg.DomainID,
Name: nullString(cfg.Name),
ClientCert: nullString(cfg.ClientCert),
ClientKey: nullString(cfg.ClientKey),
@@ -667,7 +673,7 @@ func toDBConfig(cfg bootstrap.Config) dbConfig {
func toConfig(dbcfg dbConfig) bootstrap.Config {
cfg := bootstrap.Config{
ThingID: dbcfg.ThingID,
Owner: dbcfg.Owner,
DomainID: dbcfg.DomainID,
ThingKey: dbcfg.ThingKey,
ExternalID: dbcfg.ExternalID,
ExternalKey: dbcfg.ExternalKey,
@@ -699,7 +705,7 @@ func toConfig(dbcfg dbConfig) bootstrap.Config {
type dbChannel struct {
ID string `db:"magistrala_channel"`
Name sql.NullString `db:"name"`
Owner sql.NullString `db:"owner"`
DomainID sql.NullString `db:"domain_id"`
Metadata string `db:"metadata"`
Parent sql.NullString `db:"parent_id,omitempty"`
Description string `db:"description,omitempty"`
@@ -709,11 +715,11 @@ type dbChannel struct {
Status clients.Status `db:"status"`
}
func toDBChannel(owner string, ch bootstrap.Channel) (dbChannel, error) {
func toDBChannel(domainID string, ch bootstrap.Channel) (dbChannel, error) {
dbch := dbChannel{
ID: ch.ID,
Name: nullString(ch.Name),
Owner: nullString(owner),
DomainID: nullString(domainID),
Parent: nullString(ch.Parent),
Description: ch.Description,
CreatedAt: ch.CreatedAt,
@@ -742,8 +748,8 @@ func toChannel(dbch dbChannel) (bootstrap.Channel, error) {
if dbch.Name.Valid {
ch.Name = dbch.Name.String
}
if dbch.Owner.Valid {
ch.Owner = dbch.Owner.String
if dbch.DomainID.Valid {
ch.DomainID = dbch.DomainID.String
}
if dbch.Parent.Valid {
ch.Parent = dbch.Parent.String
@@ -763,8 +769,7 @@ func toChannel(dbch dbChannel) (bootstrap.Channel, error) {
}
type dbConnection struct {
Config string `db:"config_id"`
Channel string `db:"channel_id"`
ConfigOwner string `db:"config_owner"`
ChannelOwner string `db:"channel_owner"`
Config string `db:"config_id"`
Channel string `db:"channel_id"`
DomainID string `db:"domain_id"`
}
+175 -142
View File
@@ -27,7 +27,7 @@ var (
ThingKey: "mg-key",
ExternalID: "external-id",
ExternalKey: "external-key",
Owner: "user@email.com",
DomainID: testsutil.GenerateUUID(&testing.T{}),
Channels: []bootstrap.Channel{
{ID: "1", Name: "name 1", Metadata: map[string]interface{}{"meta": 1.0}},
{ID: "2", Name: "name 2", Metadata: map[string]interface{}{"meta": 2.0}},
@@ -121,38 +121,38 @@ func TestRetrieveByID(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
cases := []struct {
desc string
owner string
id string
err error
desc string
domainID string
id string
err error
}{
{
desc: "retrieve config",
owner: c.Owner,
id: id,
err: nil,
desc: "retrieve config",
domainID: c.DomainID,
id: id,
err: nil,
},
{
desc: "retrieve config with wrong owner",
owner: "2",
id: id,
err: repoerr.ErrNotFound,
desc: "retrieve config with wrong domain ID ",
domainID: "2",
id: id,
err: repoerr.ErrNotFound,
},
{
desc: "retrieve a non-existing config",
owner: c.Owner,
id: nonexistentConfID.String(),
err: repoerr.ErrNotFound,
desc: "retrieve a non-existing config",
domainID: c.DomainID,
id: nonexistentConfID.String(),
err: repoerr.ErrNotFound,
},
{
desc: "retrieve a config with invalid ID",
owner: c.Owner,
id: "invalid",
err: repoerr.ErrNotFound,
desc: "retrieve a config with invalid ID",
domainID: c.DomainID,
id: "invalid",
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
_, err := repo.RetrieveByID(context.Background(), tc.owner, tc.id)
_, err := repo.RetrieveByID(context.Background(), tc.domainID, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
@@ -162,6 +162,8 @@ func TestRetrieveAll(t *testing.T) {
err := deleteChannels(context.Background(), repo)
require.Nil(t, err, "Channels cleanup expected to succeed.")
thingIDs := make([]string, numConfigs)
for i := 0; i < numConfigs; i++ {
c := config
@@ -173,6 +175,8 @@ func TestRetrieveAll(t *testing.T) {
c.ThingID = uid.String()
c.ThingKey = uid.String()
thingIDs[i] = c.ThingID
if i%2 == 0 {
c.State = bootstrap.Active
}
@@ -184,55 +188,85 @@ func TestRetrieveAll(t *testing.T) {
_, err = repo.Save(context.Background(), c, channels)
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
}
cases := []struct {
desc string
owner string
offset uint64
limit uint64
filter bootstrap.Filter
size int
desc string
domainID string
thingID []string
offset uint64
limit uint64
filter bootstrap.Filter
size int
}{
{
desc: "retrieve all",
owner: config.Owner,
offset: 0,
limit: uint64(numConfigs),
size: numConfigs,
desc: "retrieve all configs",
domainID: config.DomainID,
thingID: []string{},
offset: 0,
limit: uint64(numConfigs),
size: numConfigs,
},
{
desc: "retrieve subset",
owner: config.Owner,
offset: 5,
limit: uint64(numConfigs - 5),
size: numConfigs - 5,
desc: "retrieve a subset of configs",
domainID: config.DomainID,
thingID: []string{},
offset: 5,
limit: uint64(numConfigs - 5),
size: numConfigs - 5,
},
{
desc: "retrieve wrong owner",
owner: "2",
offset: 0,
limit: uint64(numConfigs),
size: 0,
desc: "retrieve with wrong domain ID ",
domainID: "2",
thingID: []string{},
offset: 0,
limit: uint64(numConfigs),
size: 0,
},
{
desc: "retrieve all active",
owner: config.Owner,
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}},
size: numConfigs / 2,
desc: "retrieve all active configs ",
domainID: config.DomainID,
thingID: []string{},
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}},
size: numConfigs / 2,
},
{
desc: "retrieve search by name",
owner: config.Owner,
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
size: 1,
desc: "retrieve all with partial match filter",
domainID: config.DomainID,
thingID: []string{},
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
size: 1,
},
{
desc: "retrieve search by name",
domainID: config.DomainID,
thingID: []string{},
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
size: 1,
},
{
desc: "retrieve by valid thingIDs",
domainID: config.DomainID,
thingID: thingIDs,
offset: 0,
limit: uint64(numConfigs),
size: 10,
},
{
desc: "retrieve by non-existing thingID",
domainID: config.DomainID,
thingID: []string{"non-existing"},
offset: 0,
limit: uint64(numConfigs),
size: 0,
},
}
for _, tc := range cases {
ret := repo.RetrieveAll(context.Background(), tc.owner, tc.filter, tc.offset, tc.limit)
ret := repo.RetrieveAll(context.Background(), tc.domainID, tc.thingID, tc.filter, tc.offset, tc.limit)
size := len(ret.Configs)
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
}
@@ -295,8 +329,8 @@ func TestUpdate(t *testing.T) {
c.Content = "new content"
c.Name = "new name"
wrongOwner := c
wrongOwner.Owner = "3"
wrongDomainID := c
wrongDomainID.DomainID = "3"
cases := []struct {
desc string
@@ -305,8 +339,8 @@ func TestUpdate(t *testing.T) {
err error
}{
{
desc: "update with wrong owner",
config: wrongOwner,
desc: "update with wrong domainID ",
config: wrongDomainID,
err: repoerr.ErrNotFound,
},
{
@@ -340,13 +374,13 @@ func TestUpdateCert(t *testing.T) {
c.Content = "new content"
c.Name = "new name"
wrongOwner := c
wrongOwner.Owner = "3"
wrongDomainID := c
wrongDomainID.DomainID = "3"
cases := []struct {
desc string
thingID string
owner string
domainID string
cert string
certKey string
ca string
@@ -354,34 +388,34 @@ func TestUpdateCert(t *testing.T) {
err error
}{
{
desc: "update with wrong owner",
desc: "update with wrong domain ID ",
thingID: "",
cert: "cert",
certKey: "certKey",
ca: "",
owner: "wrong",
domainID: wrongDomainID.DomainID,
expectedConfig: bootstrap.Config{},
err: repoerr.ErrNotFound,
},
{
desc: "update a config",
thingID: c.ThingID,
cert: "cert",
certKey: "certKey",
ca: "ca",
owner: c.Owner,
desc: "update a config",
thingID: c.ThingID,
cert: "cert",
certKey: "certKey",
ca: "ca",
domainID: c.DomainID,
expectedConfig: bootstrap.Config{
ThingID: c.ThingID,
ClientCert: "cert",
CACert: "ca",
ClientKey: "certKey",
Owner: c.Owner,
DomainID: c.DomainID,
},
err: nil,
},
}
for _, tc := range cases {
cfg, err := repo.UpdateCert(context.Background(), tc.owner, tc.thingID, tc.cert, tc.certKey, tc.ca)
cfg, err := repo.UpdateCert(context.Background(), tc.domainID, tc.thingID, tc.cert, tc.certKey, tc.ca)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg))
}
@@ -415,7 +449,7 @@ func TestUpdateConnections(t *testing.T) {
cases := []struct {
desc string
owner string
domainID string
id string
channels []bootstrap.Channel
connections []string
@@ -423,7 +457,7 @@ func TestUpdateConnections(t *testing.T) {
}{
{
desc: "update connections of non-existing config",
owner: config.Owner,
domainID: config.DomainID,
id: "unknown",
channels: nil,
connections: []string{channels[1]},
@@ -431,7 +465,7 @@ func TestUpdateConnections(t *testing.T) {
},
{
desc: "update connections",
owner: config.Owner,
domainID: config.DomainID,
id: c.ThingID,
channels: nil,
connections: []string{channels[1]},
@@ -439,7 +473,7 @@ func TestUpdateConnections(t *testing.T) {
},
{
desc: "update connections with existing channels",
owner: config.Owner,
domainID: config.DomainID,
id: c2,
channels: nil,
connections: channels,
@@ -447,7 +481,7 @@ func TestUpdateConnections(t *testing.T) {
},
{
desc: "update connections no channels",
owner: config.Owner,
domainID: config.DomainID,
id: c.ThingID,
channels: nil,
connections: nil,
@@ -455,7 +489,7 @@ func TestUpdateConnections(t *testing.T) {
},
}
for _, tc := range cases {
err := repo.UpdateConnections(context.Background(), tc.owner, tc.id, tc.channels, tc.connections)
err := repo.UpdateConnections(context.Background(), tc.domainID, tc.id, tc.channels, tc.connections)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
@@ -479,10 +513,10 @@ func TestRemove(t *testing.T) {
// Removal works the same for both existing and non-existing
// (removed) config
for i := 0; i < 2; i++ {
err := repo.Remove(context.Background(), c.Owner, id)
err := repo.Remove(context.Background(), c.DomainID, id)
assert.Nil(t, err, fmt.Sprintf("%d: failed to remove config due to: %s", i, err))
_, err = repo.RetrieveByID(context.Background(), c.Owner, id)
_, err = repo.RetrieveByID(context.Background(), c.DomainID, id)
assert.True(t, errors.Contains(err, repoerr.ErrNotFound), fmt.Sprintf("%d: expected %s got %s", i, repoerr.ErrNotFound, err))
}
}
@@ -504,41 +538,41 @@ func TestChangeState(t *testing.T) {
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
cases := []struct {
desc string
owner string
id string
state bootstrap.State
err error
desc string
domainID string
id string
state bootstrap.State
err error
}{
{
desc: "change state with wrong owner",
id: saved,
owner: "2",
err: repoerr.ErrNotFound,
desc: "change state with wrong domain ID ",
id: saved,
domainID: "2",
err: repoerr.ErrNotFound,
},
{
desc: "change state with wrong id",
id: "wrong",
owner: c.Owner,
err: repoerr.ErrNotFound,
desc: "change state with wrong id",
id: "wrong",
domainID: c.DomainID,
err: repoerr.ErrNotFound,
},
{
desc: "change state to Active",
id: saved,
owner: c.Owner,
state: bootstrap.Active,
err: nil,
desc: "change state to Active",
id: saved,
domainID: c.DomainID,
state: bootstrap.Active,
err: nil,
},
{
desc: "change state to Inactive",
id: saved,
owner: c.Owner,
state: bootstrap.Inactive,
err: nil,
desc: "change state to Inactive",
id: saved,
domainID: c.DomainID,
state: bootstrap.Inactive,
err: nil,
},
}
for _, tc := range cases {
err := repo.ChangeState(context.Background(), tc.owner, tc.id, tc.state)
err := repo.ChangeState(context.Background(), tc.domainID, tc.id, tc.state)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
@@ -564,31 +598,31 @@ func TestListExisting(t *testing.T) {
cases := []struct {
desc string
owner string
domainID string
connections []string
existing []bootstrap.Channel
}{
{
desc: "list all existing channels",
owner: c.Owner,
domainID: c.DomainID,
connections: channels,
existing: chs,
},
{
desc: "list a subset of existing channels",
owner: c.Owner,
domainID: c.DomainID,
connections: []string{channels[0], "5"},
existing: []bootstrap.Channel{chs[0]},
},
{
desc: "list a subset of existing channels empty",
owner: c.Owner,
domainID: c.DomainID,
connections: []string{"5", "6"},
existing: []bootstrap.Channel{},
},
}
for _, tc := range cases {
existing, err := repo.ListExisting(context.Background(), tc.owner, tc.connections)
existing, err := repo.ListExisting(context.Background(), tc.domainID, tc.connections)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error: %s", tc.desc, err))
assert.ElementsMatch(t, tc.existing, existing, fmt.Sprintf("%s: Got non-matching elements.", tc.desc))
}
@@ -640,7 +674,7 @@ func TestUpdateChannel(t *testing.T) {
err = repo.UpdateChannel(context.Background(), update)
assert.Nil(t, err, fmt.Sprintf("updating config expected to succeed: %s.\n", err))
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
var retreved bootstrap.Channel
for _, c := range cfg.Channels {
@@ -649,7 +683,7 @@ func TestUpdateChannel(t *testing.T) {
break
}
}
update.Owner = retreved.Owner
update.DomainID = retreved.DomainID
assert.Equal(t, update, retreved, fmt.Sprintf("expected %s, go %s", update, retreved))
}
@@ -671,7 +705,7 @@ func TestRemoveChannel(t *testing.T) {
err = repo.RemoveChannel(context.Background(), c.Channels[0].ID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.NotContains(t, cfg.Channels, c.Channels[0], fmt.Sprintf("expected to remove channel %s from %s", c.Channels[0], cfg.Channels))
}
@@ -704,38 +738,37 @@ func TestConnectThing(t *testing.T) {
emptyThing := c
emptyThing.ThingID = ""
emptyThing.ThingKey = ""
emptyThing.ExternalID = ""
emptyThing.ExternalKey = ""
emptyThing.Channels = []bootstrap.Channel{}
cases := []struct {
desc string
owner string
domainID string
id string
state bootstrap.State
channels []bootstrap.Channel
connections []string
err error
}{
{
desc: "connect disconnected thing",
owner: config.Owner,
domainID: c.DomainID,
id: saved,
state: bootstrap.Inactive,
channels: c.Channels,
connections: channels,
err: nil,
},
{
desc: "connect already connected thing",
owner: config.Owner,
domainID: c.DomainID,
id: connectedThing.ThingID,
state: connectedThing.State,
channels: c.Channels,
connections: channels,
err: nil,
},
{
desc: "connect non-existent thing",
owner: config.Owner,
domainID: c.DomainID,
id: wrongID,
channels: c.Channels,
connections: channels,
@@ -743,7 +776,7 @@ func TestConnectThing(t *testing.T) {
},
{
desc: "connect random thing",
owner: config.Owner,
domainID: c.DomainID,
id: randomThing.ThingID,
channels: c.Channels,
connections: channels,
@@ -751,7 +784,7 @@ func TestConnectThing(t *testing.T) {
},
{
desc: "connect empty thing",
owner: config.Owner,
domainID: c.DomainID,
id: emptyThing.ThingID,
channels: c.Channels,
connections: channels,
@@ -763,7 +796,7 @@ func TestConnectThing(t *testing.T) {
if i == 0 {
err = repo.ConnectThing(context.Background(), ch.ID, tc.id)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err))
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg))
} else {
@@ -771,7 +804,7 @@ func TestConnectThing(t *testing.T) {
}
}
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg))
}
@@ -805,57 +838,57 @@ func TestDisconnectThing(t *testing.T) {
emptyThing := c
emptyThing.ThingID = ""
emptyThing.ThingKey = ""
emptyThing.ExternalID = ""
emptyThing.ExternalKey = ""
cases := []struct {
desc string
owner string
domainID string
id string
state bootstrap.State
channels []bootstrap.Channel
connections []string
err error
}{
{
desc: "disconnect connected thing",
owner: config.Owner,
domainID: c.DomainID,
id: connectedThing.ThingID,
state: connectedThing.State,
channels: c.Channels,
connections: channels,
err: nil,
},
{
desc: "disconnect already disconnected thing",
owner: config.Owner,
domainID: c.DomainID,
id: saved,
state: bootstrap.Inactive,
channels: c.Channels,
connections: channels,
err: nil,
},
{
desc: "disconnect invalid thing",
owner: config.Owner,
domainID: c.DomainID,
id: wrongID,
channels: c.Channels,
connections: channels,
err: repoerr.ErrNotFound,
err: nil,
},
{
desc: "disconnect random thing",
owner: config.Owner,
domainID: c.DomainID,
id: randomThing.ThingID,
channels: c.Channels,
connections: channels,
err: repoerr.ErrNotFound,
err: nil,
},
{
desc: "disconnect empty thing",
owner: config.Owner,
domainID: c.DomainID,
id: emptyThing.ThingID,
channels: c.Channels,
connections: channels,
err: repoerr.ErrNotFound,
err: nil,
},
}
@@ -864,11 +897,11 @@ func TestDisconnectThing(t *testing.T) {
err = repo.DisconnectThing(context.Background(), ch.ID, tc.id)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err))
}
}
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.Equal(t, cfg.State, bootstrap.Inactive, fmt.Sprintf("expected to be inactive when a connection is removed from %s", cfg))
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.Equal(t, cfg.State, bootstrap.Inactive, fmt.Sprintf("expected to be inactive when a connection is removed from %s", cfg))
}
}
func deleteChannels(ctx context.Context, repo bootstrap.ConfigRepository) error {
+20
View File
@@ -83,6 +83,26 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE IF EXISTS channels RENAME COLUMN mainflux_channel TO magistrala_channel`,
},
},
{
Id: "configs_5",
Up: []string{
`ALTER TABLE IF EXISTS configs RENAME COLUMN owner TO domain_id`,
`ALTER TABLE IF EXISTS channels RENAME COLUMN owner TO domain_id`,
`ALTER TABLE IF EXISTS configs ADD CONSTRAINT configs_name_domain_id_key UNIQUE (name, domain_id)`,
},
},
{
Id: "configs_6",
Up: []string{
`ALTER TABLE IF EXISTS connections DROP CONSTRAINT IF EXISTS connections_pkey`,
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS channel_owner`,
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS config_owner`,
`ALTER TABLE IF EXISTS connections ADD COLUMN IF NOT EXISTS domain_id VARCHAR(256) NOT NULL`,
`ALTER TABLE IF EXISTS connections ADD CONSTRAINT connections_pkey PRIMARY KEY (channel_id, config_id, domain_id)`,
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (channel_id, domain_id) REFERENCES channels (magistrala_channel, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (config_id, domain_id) REFERENCES configs (magistrala_thing, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
},
},
},
}
}
+133 -25
View File
@@ -11,6 +11,7 @@ import (
"time"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
@@ -34,6 +35,9 @@ var (
// ErrAddBootstrap indicates error in adding bootstrap configuration.
ErrAddBootstrap = errors.New("failed to add bootstrap configuration")
// ErrNotInSameDomain indicates entities are not in the same domain.
errNotInSameDomain = errors.New("entities are not in the same domain")
errUpdateConnections = errors.New("failed to update connections")
errRemoveBootstrap = errors.New("failed to remove bootstrap configuration")
errChangeState = errors.New("failed to change state of bootstrap configuration")
@@ -82,7 +86,7 @@ type Service interface {
// Bootstrap returns Config to the Thing with provided external ID using external key.
Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error)
// ChangeState changes state of the Thing with given ID and owner.
// ChangeState changes state of the Thing with given thing ID and domain ID.
ChangeState(ctx context.Context, token, id string, state State) error
// Methods RemoveConfig, UpdateChannel, and RemoveChannel are used as
@@ -123,26 +127,29 @@ type bootstrapService struct {
}
// New returns new Bootstrap service.
func New(auth magistrala.AuthServiceClient, configs ConfigRepository, sdk mgsdk.SDK, encKey []byte, idp magistrala.IDProvider) Service {
func New(uauth magistrala.AuthServiceClient, configs ConfigRepository, sdk mgsdk.SDK, encKey []byte, idp magistrala.IDProvider) Service {
return &bootstrapService{
configs: configs,
sdk: sdk,
auth: auth,
auth: uauth,
encKey: encKey,
idProvider: idp,
}
}
func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (Config, error) {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if _, err := bs.authorize(ctx, "", auth.UsersKind, user.GetId(), auth.MembershipPermission, auth.DomainType, user.GetDomainId()); err != nil {
return Config{}, err
}
toConnect := bs.toIDList(cfg.Channels)
// Check if channels exist. This is the way to prevent fetching channels that already exist.
existing, err := bs.configs.ListExisting(ctx, owner, toConnect)
existing, err := bs.configs.ListExisting(ctx, user.GetDomainId(), toConnect)
if err != nil {
return Config{}, errors.Wrap(errCheckChannels, err)
}
@@ -158,8 +165,14 @@ func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (C
return Config{}, errors.Wrap(errThingNotFound, err)
}
for _, channel := range cfg.Channels {
if channel.DomainID != mgThing.DomainID {
return Config{}, errors.Wrap(svcerr.ErrMalformedEntity, errNotInSameDomain)
}
}
cfg.ThingID = mgThing.ID
cfg.Owner = owner
cfg.DomainID = user.GetDomainId()
cfg.State = Inactive
cfg.ThingKey = mgThing.Credentials.Secret
@@ -182,11 +195,14 @@ func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (C
}
func (bs bootstrapService) View(ctx context.Context, token, id string) (Config, error) {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.ViewPermission, auth.ThingType, id); err != nil {
return Config{}, err
}
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
@@ -194,12 +210,15 @@ func (bs bootstrapService) View(ctx context.Context, token, id string) (Config,
}
func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config) error {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, cfg.ThingID); err != nil {
return err
}
cfg.Owner = owner
cfg.DomainID = user.GetDomainId()
if err = bs.configs.Update(ctx, cfg); err != nil {
return errors.Wrap(errUpdateConnections, err)
}
@@ -207,11 +226,15 @@ func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config)
}
func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clientCert, clientKey, caCert string) (Config, error) {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.UpdateCert(ctx, owner, thingID, clientCert, clientKey, caCert)
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, thingID); err != nil {
return Config{}, err
}
cfg, err := bs.configs.UpdateCert(ctx, user.GetDomainId(), thingID, clientCert, clientKey, caCert)
if err != nil {
return Config{}, errors.Wrap(errUpdateCert, err)
}
@@ -219,12 +242,16 @@ func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clien
}
func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id string, connections []string) error {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, id); err != nil {
return err
}
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
if err != nil {
return errors.Wrap(errUpdateConnections, err)
}
@@ -232,7 +259,7 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id stri
add, remove := bs.updateList(cfg, connections)
// Check if channels exist. This is the way to prevent fetching channels that already exist.
existing, err := bs.configs.ListExisting(ctx, owner, connections)
existing, err := bs.configs.ListExisting(ctx, user.GetDomainId(), connections)
if err != nil {
return errors.Wrap(errUpdateConnections, err)
}
@@ -268,26 +295,83 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id stri
return ErrThings
}
}
if err := bs.configs.UpdateConnections(ctx, owner, id, channels, connections); err != nil {
if err := bs.configs.UpdateConnections(ctx, user.GetDomainId(), id, channels, connections); err != nil {
return errors.Wrap(errUpdateConnections, err)
}
return nil
}
func (bs bootstrapService) listClientIDs(ctx context.Context, userID string) ([]string, error) {
tids, err := bs.auth.ListAllObjects(ctx, &magistrala.ListObjectsReq{
SubjectType: auth.UserType,
Subject: userID,
Permission: auth.ViewPermission,
ObjectType: auth.ThingType,
})
if err != nil {
return nil, errors.Wrap(svcerr.ErrNotFound, err)
}
return tids.Policies, nil
}
func (bs bootstrapService) checkSuperAdmin(ctx context.Context, userID string) error {
res, err := bs.auth.Authorize(ctx, &magistrala.AuthorizeReq{
SubjectType: auth.UserType,
Subject: userID,
Permission: auth.AdminPermission,
ObjectType: auth.PlatformType,
Object: auth.MagistralaObject,
})
if err != nil {
return err
}
if !res.Authorized {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
func (bs bootstrapService) List(ctx context.Context, token string, filter Filter, offset, limit uint64) (ConfigsPage, error) {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return ConfigsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return bs.configs.RetrieveAll(ctx, owner, filter, offset, limit), nil
if err := bs.checkSuperAdmin(ctx, user.GetId()); err == nil {
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), []string{}, filter, offset, limit), nil
}
if _, err := bs.authorize(ctx, "", auth.UsersKind, user.GetId(), auth.AdminPermission, auth.DomainType, user.GetDomainId()); err == nil {
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), []string{}, filter, offset, limit), nil
}
// Handle non-admin users
thingIDs, err := bs.listClientIDs(ctx, user.GetId())
if err != nil {
return ConfigsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
if len(thingIDs) == 0 {
return ConfigsPage{
Total: 0,
Offset: offset,
Limit: limit,
Configs: []Config{},
}, nil
}
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), thingIDs, filter, offset, limit), nil
}
func (bs bootstrapService) Remove(ctx context.Context, token, id string) error {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := bs.configs.Remove(ctx, owner, id); err != nil {
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.DeletePermission, auth.ThingType, id); err != nil {
return err
}
if err := bs.configs.Remove(ctx, user.GetDomainId(), id); err != nil {
return errors.Wrap(errRemoveBootstrap, err)
}
return nil
@@ -313,12 +397,12 @@ func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalI
}
func (bs bootstrapService) ChangeState(ctx context.Context, token, id string, state State) error {
owner, err := bs.identify(ctx, token)
user, err := bs.identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
if err != nil {
return errors.Wrap(errChangeState, err)
}
@@ -352,7 +436,7 @@ func (bs bootstrapService) ChangeState(ctx context.Context, token, id string, st
}
}
}
if err := bs.configs.ChangeState(ctx, owner, id, state); err != nil {
if err := bs.configs.ChangeState(ctx, user.GetDomainId(), id, state); err != nil {
return errors.Wrap(errChangeState, err)
}
return nil
@@ -393,13 +477,36 @@ func (bs bootstrapService) DisconnectThingHandler(ctx context.Context, channelID
return nil
}
func (bs bootstrapService) identify(ctx context.Context, token string) (string, error) {
func (bs bootstrapService) identify(ctx context.Context, token string) (*magistrala.IdentityRes, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
res, err := bs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
}
if res.GetId() == "" || res.GetDomainId() == "" {
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
}
return res, nil
}
func (bs bootstrapService) authorize(ctx context.Context, domainID, subjKind, subj, perm, objType, obj string) (string, error) {
req := &magistrala.AuthorizeReq{
Domain: domainID,
SubjectType: auth.UserType,
SubjectKind: subjKind,
Subject: subj,
Permission: perm,
ObjectType: objType,
Object: obj,
}
res, err := bs.auth.Authorize(ctx, req)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
return res.GetId(), nil
@@ -451,6 +558,7 @@ func (bs bootstrapService) connectionChannels(channels, existing []string, token
ID: ch.ID,
Name: ch.Name,
Metadata: ch.Metadata,
DomainID: ch.DomainID,
})
}
+834 -312
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -27,7 +27,7 @@ func New(svc bootstrap.Service, tracer trace.Tracer) bootstrap.Service {
func (tm *tracingMiddleware) Add(ctx context.Context, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_register_client", trace.WithAttributes(
attribute.String("thing_id", cfg.ThingID),
attribute.String("owner", cfg.Owner),
attribute.String("domain_id ", cfg.DomainID),
attribute.String("name", cfg.Name),
attribute.String("external_id", cfg.ExternalID),
attribute.String("content", cfg.Content),
@@ -54,7 +54,7 @@ func (tm *tracingMiddleware) Update(ctx context.Context, token string, cfg boots
attribute.String("name", cfg.Name),
attribute.String("content", cfg.Content),
attribute.String("thing_id", cfg.ThingID),
attribute.String("owner", cfg.Owner),
attribute.String("domain_id ", cfg.DomainID),
))
defer span.End()
+1 -1
View File
@@ -42,7 +42,7 @@ type Client struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Tags []string `json:"tags,omitempty"`
Domain string `json:"domain,omitempty"`
Domain string `json:"domain_id,omitempty"`
Credentials Credentials `json:"credentials,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`