mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 07:40:17 +00:00
MG-1955 - Update Bootstrap service access control (#2199)
Signed-off-by: JeffMboya <jangina.mboya@gmail.com>
This commit is contained in:
+66
-100
@@ -328,25 +328,20 @@ func TestView(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
var channels []channel
|
||||
for _, ch := range saved.Channels {
|
||||
for _, ch := range c.Channels {
|
||||
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
|
||||
}
|
||||
|
||||
data := config{
|
||||
ThingID: saved.ThingID,
|
||||
ThingKey: saved.ThingKey,
|
||||
State: saved.State,
|
||||
ThingID: c.ThingID,
|
||||
ThingKey: c.ThingKey,
|
||||
State: c.State,
|
||||
Channels: channels,
|
||||
ExternalID: saved.ExternalID,
|
||||
ExternalKey: saved.ExternalKey,
|
||||
Name: saved.Name,
|
||||
Content: saved.Content,
|
||||
ExternalID: c.ExternalID,
|
||||
ExternalKey: c.ExternalKey,
|
||||
Name: c.Name,
|
||||
Content: c.Content,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
@@ -360,7 +355,7 @@ func TestView(t *testing.T) {
|
||||
{
|
||||
desc: "view a config with invalid token",
|
||||
auth: invalidToken,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
status: http.StatusUnauthorized,
|
||||
res: config{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
@@ -368,7 +363,7 @@ func TestView(t *testing.T) {
|
||||
{
|
||||
desc: "view a config",
|
||||
auth: validToken,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
status: http.StatusOK,
|
||||
res: data,
|
||||
err: nil,
|
||||
@@ -384,15 +379,23 @@ func TestView(t *testing.T) {
|
||||
{
|
||||
desc: "view a config with an empty token",
|
||||
auth: "",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
status: http.StatusUnauthorized,
|
||||
res: config{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "view config without authorization",
|
||||
auth: validToken,
|
||||
id: c.ThingID,
|
||||
status: http.StatusForbidden,
|
||||
res: config{},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("View", mock.Anything, mock.Anything, mock.Anything).Return(c, tc.err)
|
||||
svcCall := svc.On("View", mock.Anything, tc.auth, tc.id).Return(c, tc.err)
|
||||
req := testRequest{
|
||||
client: bs.Client(),
|
||||
method: http.MethodGet,
|
||||
@@ -422,11 +425,6 @@ func TestUpdate(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
data := toJSON(updateReq)
|
||||
|
||||
cases := []struct {
|
||||
@@ -441,7 +439,7 @@ func TestUpdate(t *testing.T) {
|
||||
{
|
||||
desc: "update with invalid token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: invalidToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -450,7 +448,7 @@ func TestUpdate(t *testing.T) {
|
||||
{
|
||||
desc: "update with an empty token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: "",
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -459,7 +457,7 @@ func TestUpdate(t *testing.T) {
|
||||
{
|
||||
desc: "update a valid config",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusOK,
|
||||
@@ -468,7 +466,7 @@ func TestUpdate(t *testing.T) {
|
||||
{
|
||||
desc: "update a config with wrong content type",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: "",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
@@ -486,7 +484,7 @@ func TestUpdate(t *testing.T) {
|
||||
{
|
||||
desc: "update a config with invalid request format",
|
||||
req: "}",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
@@ -494,7 +492,7 @@ func TestUpdate(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update a config with an empty request",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
req: "",
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
@@ -504,7 +502,7 @@ func TestUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svcCall.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
|
||||
svcCall := svc.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
|
||||
req := testRequest{
|
||||
client: bs.Client(),
|
||||
method: http.MethodPut,
|
||||
@@ -525,11 +523,6 @@ func TestUpdateCert(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
data := toJSON(updateReq)
|
||||
|
||||
cases := []struct {
|
||||
@@ -544,7 +537,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
{
|
||||
desc: "update with invalid token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: invalidToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -553,7 +546,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
{
|
||||
desc: "update with an empty token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: "",
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -562,7 +555,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
{
|
||||
desc: "update a valid config",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusOK,
|
||||
@@ -571,7 +564,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
{
|
||||
desc: "update a config with wrong content type",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: "",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
@@ -589,7 +582,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
{
|
||||
desc: "update a config with invalid request format",
|
||||
req: "}",
|
||||
id: saved.ThingKey,
|
||||
id: c.ThingKey,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
@@ -597,7 +590,7 @@ func TestUpdateCert(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update a config with an empty request",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
req: "",
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
@@ -627,12 +620,6 @@ func TestUpdateConnections(t *testing.T) {
|
||||
bs, svc := newBootstrapServer()
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
data := toJSON(updateReq)
|
||||
|
||||
invalidChannels := updateReq
|
||||
@@ -652,7 +639,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update connections with invalid token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: invalidToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -661,7 +648,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update connections with an empty token",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: "",
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
@@ -670,7 +657,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update connections valid config",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusOK,
|
||||
@@ -679,7 +666,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update connections with wrong content type",
|
||||
req: data,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: "",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
@@ -697,7 +684,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update connections with invalid channels",
|
||||
req: wrongData,
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusNotFound,
|
||||
@@ -706,7 +693,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
{
|
||||
desc: "update a config with invalid request format",
|
||||
req: "}",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
@@ -714,7 +701,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update a config with an empty request",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
req: "",
|
||||
auth: validToken,
|
||||
contentType: contentType,
|
||||
@@ -724,7 +711,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := svcCall.On("UpdateConnections", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
|
||||
repoCall := svc.On("UpdateConnections", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
|
||||
req := testRequest{
|
||||
client: bs.Client(),
|
||||
method: http.MethodPut,
|
||||
@@ -758,28 +745,22 @@ func TestList(t *testing.T) {
|
||||
c.Name = fmt.Sprintf("%s-%d", addName, i)
|
||||
c.ExternalKey = fmt.Sprintf("%s%s", addExternalKey, strconv.Itoa(i))
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
var channels []channel
|
||||
for _, ch := range saved.Channels {
|
||||
for _, ch := range c.Channels {
|
||||
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
|
||||
}
|
||||
s := config{
|
||||
ThingID: saved.ThingID,
|
||||
ThingKey: saved.ThingKey,
|
||||
ThingID: c.ThingID,
|
||||
ThingKey: c.ThingKey,
|
||||
Channels: channels,
|
||||
ExternalID: saved.ExternalID,
|
||||
ExternalKey: saved.ExternalKey,
|
||||
Name: saved.Name,
|
||||
Content: saved.Content,
|
||||
State: saved.State,
|
||||
ExternalID: c.ExternalID,
|
||||
ExternalKey: c.ExternalKey,
|
||||
Name: c.Name,
|
||||
Content: c.Content,
|
||||
State: c.State,
|
||||
}
|
||||
list[i] = s
|
||||
}
|
||||
|
||||
// Change state of first 20 elements for filtering tests.
|
||||
for i := 0; i < changedStateNum; i++ {
|
||||
state := bootstrap.Active
|
||||
@@ -1020,11 +1001,6 @@ func TestRemove(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
@@ -1034,14 +1010,14 @@ func TestRemove(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "remove with invalid token",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: invalidToken,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "remove with an empty token",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: "",
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
@@ -1055,7 +1031,7 @@ func TestRemove(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "remove config",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
status: http.StatusNoContent,
|
||||
err: nil,
|
||||
@@ -1089,16 +1065,11 @@ func TestBootstrap(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
encExternKey, err := enc([]byte(c.ExternalKey))
|
||||
assert.Nil(t, err, fmt.Sprintf("Encrypting config expected to succeed: %s.\n", err))
|
||||
|
||||
var channels []channel
|
||||
for _, ch := range saved.Channels {
|
||||
for _, ch := range c.Channels {
|
||||
channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata})
|
||||
}
|
||||
|
||||
@@ -1111,13 +1082,13 @@ func TestBootstrap(t *testing.T) {
|
||||
ClientKey string `json:"client_key"`
|
||||
CACert string `json:"ca_cert"`
|
||||
}{
|
||||
ThingID: saved.ThingID,
|
||||
ThingKey: saved.ThingKey,
|
||||
ThingID: c.ThingID,
|
||||
ThingKey: c.ThingKey,
|
||||
Channels: channels,
|
||||
Content: saved.Content,
|
||||
ClientCert: saved.ClientCert,
|
||||
ClientKey: saved.ClientKey,
|
||||
CACert: saved.CACert,
|
||||
Content: c.Content,
|
||||
ClientCert: c.ClientCert,
|
||||
ClientKey: c.ClientKey,
|
||||
CACert: c.CACert,
|
||||
}
|
||||
|
||||
data := toJSON(s)
|
||||
@@ -1225,11 +1196,6 @@ func TestChangeState(t *testing.T) {
|
||||
defer bs.Close()
|
||||
c := newConfig()
|
||||
|
||||
svcCall := svc.On("Add", context.Background(), mock.Anything, mock.Anything).Return(c, nil)
|
||||
saved, err := svc.Add(context.Background(), validToken, c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
svcCall.Unset()
|
||||
|
||||
inactive := fmt.Sprintf("{\"state\": %d}", bootstrap.Inactive)
|
||||
active := fmt.Sprintf("{\"state\": %d}", bootstrap.Active)
|
||||
|
||||
@@ -1244,7 +1210,7 @@ func TestChangeState(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "change state with invalid token",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: invalidToken,
|
||||
state: active,
|
||||
contentType: contentType,
|
||||
@@ -1253,7 +1219,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state with an empty token",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: "",
|
||||
state: active,
|
||||
contentType: contentType,
|
||||
@@ -1262,7 +1228,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state with invalid content type",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
state: active,
|
||||
contentType: "",
|
||||
@@ -1271,7 +1237,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state to active",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
state: active,
|
||||
contentType: contentType,
|
||||
@@ -1280,7 +1246,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state to inactive",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
state: inactive,
|
||||
contentType: contentType,
|
||||
@@ -1298,7 +1264,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state to invalid value",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
state: fmt.Sprintf("{\"state\": %d}", -3),
|
||||
contentType: contentType,
|
||||
@@ -1307,7 +1273,7 @@ func TestChangeState(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "change state with invalid data",
|
||||
id: saved.ThingID,
|
||||
id: c.ThingID,
|
||||
auth: validToken,
|
||||
state: "",
|
||||
contentType: contentType,
|
||||
|
||||
@@ -36,6 +36,16 @@ func (req addReq) validate() error {
|
||||
return apiutil.ErrBearerKey
|
||||
}
|
||||
|
||||
if len(req.Channels) == 0 {
|
||||
return apiutil.ErrEmptyList
|
||||
}
|
||||
|
||||
for _, channel := range req.Channels {
|
||||
if channel == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,23 +8,39 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/apiutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
channel1 = testsutil.GenerateUUID(&testing.T{})
|
||||
channel2 = testsutil.GenerateUUID(&testing.T{})
|
||||
)
|
||||
|
||||
func TestAddReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
externalID string
|
||||
externalKey string
|
||||
channels []string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "empty key",
|
||||
desc: "valid request",
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
channels: []string{channel1, channel2},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty token",
|
||||
token: "",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
channels: []string{channel1, channel2},
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
@@ -32,6 +48,7 @@ func TestAddReqValidation(t *testing.T) {
|
||||
token: "token",
|
||||
externalID: "",
|
||||
externalKey: "external-key",
|
||||
channels: []string{channel1, channel2},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
@@ -39,8 +56,33 @@ func TestAddReqValidation(t *testing.T) {
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "",
|
||||
channels: []string{channel1, channel2},
|
||||
err: apiutil.ErrBearerKey,
|
||||
},
|
||||
{
|
||||
desc: "empty external key and external ID",
|
||||
token: "token",
|
||||
externalID: "",
|
||||
externalKey: "",
|
||||
channels: []string{channel1, channel2},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "empty channels",
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
channels: []string{},
|
||||
err: apiutil.ErrEmptyList,
|
||||
},
|
||||
{
|
||||
desc: "empty channel value",
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
channels: []string{channel1, ""},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
@@ -48,6 +90,7 @@ func TestAddReqValidation(t *testing.T) {
|
||||
token: tc.token,
|
||||
ExternalID: tc.externalID,
|
||||
ExternalKey: tc.externalKey,
|
||||
Channels: tc.channels,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
@@ -93,6 +136,12 @@ func TestUpdateReqValidation(t *testing.T) {
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
token: "token",
|
||||
id: "id",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty token",
|
||||
token: "",
|
||||
|
||||
+10
-10
@@ -17,7 +17,7 @@ import (
|
||||
// MGChannels is a list of Magistrala Channels corresponding Magistrala Thing connects to.
|
||||
type Config struct {
|
||||
ThingID string `json:"thing_id"`
|
||||
Owner string `json:"owner,omitempty"`
|
||||
DomainID string `json:"domain_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
@@ -35,7 +35,7 @@ type Channel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Owner string `json:"owner_id"`
|
||||
DomainID string `json:"domain_id"`
|
||||
Parent string `json:"parent_id,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -69,11 +69,11 @@ type ConfigRepository interface {
|
||||
|
||||
// RetrieveByID retrieves the Config having the provided identifier, that is owned
|
||||
// by the specified user.
|
||||
RetrieveByID(ctx context.Context, owner, id string) (Config, error)
|
||||
RetrieveByID(ctx context.Context, domainID, id string) (Config, error)
|
||||
|
||||
// RetrieveAll retrieves a subset of Configs that are owned
|
||||
// by the specific user, with given filter parameters.
|
||||
RetrieveAll(ctx context.Context, owner string, filter Filter, offset, limit uint64) ConfigsPage
|
||||
RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter Filter, offset, limit uint64) ConfigsPage
|
||||
|
||||
// RetrieveByExternalID returns Config for given external ID.
|
||||
RetrieveByExternalID(ctx context.Context, externalID string) (Config, error)
|
||||
@@ -82,23 +82,23 @@ type ConfigRepository interface {
|
||||
// to indicate operation failure.
|
||||
Update(ctx context.Context, cfg Config) error
|
||||
|
||||
// UpdateCerts updates and returns an existing Config certificate and owner.
|
||||
// UpdateCerts updates and returns an existing Config certificate and domainID.
|
||||
// A non-nil error is returned to indicate operation failure.
|
||||
UpdateCert(ctx context.Context, owner, thingID, clientCert, clientKey, caCert string) (Config, error)
|
||||
UpdateCert(ctx context.Context, domainID, thingID, clientCert, clientKey, caCert string) (Config, error)
|
||||
|
||||
// UpdateConnections updates a list of Channels the Config is connected to
|
||||
// adding new Channels if needed.
|
||||
UpdateConnections(ctx context.Context, owner, id string, channels []Channel, connections []string) error
|
||||
UpdateConnections(ctx context.Context, domainID, id string, channels []Channel, connections []string) error
|
||||
|
||||
// Remove removes the Config having the provided identifier, that is owned
|
||||
// by the specified user.
|
||||
Remove(ctx context.Context, owner, id string) error
|
||||
Remove(ctx context.Context, domainID, id string) error
|
||||
|
||||
// ChangeState changes of the Config, that is owned by the specific user.
|
||||
ChangeState(ctx context.Context, owner, id string, state State) error
|
||||
ChangeState(ctx context.Context, domainID, id string, state State) error
|
||||
|
||||
// ListExisting retrieves those channels from the given list that exist in DB.
|
||||
ListExisting(ctx context.Context, owner string, ids []string) ([]Channel, error)
|
||||
ListExisting(ctx context.Context, domainID string, ids []string) ([]Channel, error)
|
||||
|
||||
// Methods RemoveThing, UpdateChannel, and RemoveChannel are related to
|
||||
// event sourcing. That's why these methods surpass ownership check.
|
||||
|
||||
@@ -58,8 +58,8 @@ func (ce configEvent) Encode() (map[string]interface{}, error) {
|
||||
if ce.Content != "" {
|
||||
val["content"] = ce.Content
|
||||
}
|
||||
if ce.Owner != "" {
|
||||
val["owner"] = ce.Owner
|
||||
if ce.DomainID != "" {
|
||||
val["domain_id "] = ce.DomainID
|
||||
}
|
||||
if ce.Name != "" {
|
||||
val["name"] = ce.Name
|
||||
@@ -143,8 +143,8 @@ func (be bootstrapEvent) Encode() (map[string]interface{}, error) {
|
||||
if be.Content != "" {
|
||||
val["content"] = be.Content
|
||||
}
|
||||
if be.Owner != "" {
|
||||
val["owner"] = be.Owner
|
||||
if be.DomainID != "" {
|
||||
val["domain_id "] = be.DomainID
|
||||
}
|
||||
if be.Name != "" {
|
||||
val["name"] = be.Name
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+35
-35
@@ -17,9 +17,9 @@ type ConfigRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// ChangeState provides a mock function with given fields: ctx, owner, id, state
|
||||
func (_m *ConfigRepository) ChangeState(ctx context.Context, owner string, id string, state bootstrap.State) error {
|
||||
ret := _m.Called(ctx, owner, id, state)
|
||||
// ChangeState provides a mock function with given fields: ctx, domainID, id, state
|
||||
func (_m *ConfigRepository) ChangeState(ctx context.Context, domainID string, id string, state bootstrap.State) error {
|
||||
ret := _m.Called(ctx, domainID, id, state)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ChangeState")
|
||||
@@ -27,7 +27,7 @@ func (_m *ConfigRepository) ChangeState(ctx context.Context, owner string, id st
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, bootstrap.State) error); ok {
|
||||
r0 = rf(ctx, owner, id, state)
|
||||
r0 = rf(ctx, domainID, id, state)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@@ -71,9 +71,9 @@ func (_m *ConfigRepository) DisconnectThing(ctx context.Context, channelID strin
|
||||
return r0
|
||||
}
|
||||
|
||||
// ListExisting provides a mock function with given fields: ctx, owner, ids
|
||||
func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids []string) ([]bootstrap.Channel, error) {
|
||||
ret := _m.Called(ctx, owner, ids)
|
||||
// ListExisting provides a mock function with given fields: ctx, domainID, ids
|
||||
func (_m *ConfigRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) {
|
||||
ret := _m.Called(ctx, domainID, ids)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListExisting")
|
||||
@@ -82,10 +82,10 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
|
||||
var r0 []bootstrap.Channel
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, []string) ([]bootstrap.Channel, error)); ok {
|
||||
return rf(ctx, owner, ids)
|
||||
return rf(ctx, domainID, ids)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, []string) []bootstrap.Channel); ok {
|
||||
r0 = rf(ctx, owner, ids)
|
||||
r0 = rf(ctx, domainID, ids)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]bootstrap.Channel)
|
||||
@@ -93,7 +93,7 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok {
|
||||
r1 = rf(ctx, owner, ids)
|
||||
r1 = rf(ctx, domainID, ids)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -101,9 +101,9 @@ func (_m *ConfigRepository) ListExisting(ctx context.Context, owner string, ids
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Remove provides a mock function with given fields: ctx, owner, id
|
||||
func (_m *ConfigRepository) Remove(ctx context.Context, owner string, id string) error {
|
||||
ret := _m.Called(ctx, owner, id)
|
||||
// Remove provides a mock function with given fields: ctx, domainID, id
|
||||
func (_m *ConfigRepository) Remove(ctx context.Context, domainID string, id string) error {
|
||||
ret := _m.Called(ctx, domainID, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Remove")
|
||||
@@ -111,7 +111,7 @@ func (_m *ConfigRepository) Remove(ctx context.Context, owner string, id string)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = rf(ctx, owner, id)
|
||||
r0 = rf(ctx, domainID, id)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@@ -155,17 +155,17 @@ func (_m *ConfigRepository) RemoveThing(ctx context.Context, id string) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
// RetrieveAll provides a mock function with given fields: ctx, owner, filter, offset, limit
|
||||
func (_m *ConfigRepository) RetrieveAll(ctx context.Context, owner string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
|
||||
ret := _m.Called(ctx, owner, filter, offset, limit)
|
||||
// RetrieveAll provides a mock function with given fields: ctx, domainID, thingIDs, filter, offset, limit
|
||||
func (_m *ConfigRepository) RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
|
||||
ret := _m.Called(ctx, domainID, thingIDs, filter, offset, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveAll")
|
||||
}
|
||||
|
||||
var r0 bootstrap.ConfigsPage
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
|
||||
r0 = rf(ctx, owner, filter, offset, limit)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, []string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
|
||||
r0 = rf(ctx, domainID, thingIDs, filter, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.ConfigsPage)
|
||||
}
|
||||
@@ -201,9 +201,9 @@ func (_m *ConfigRepository) RetrieveByExternalID(ctx context.Context, externalID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveByID provides a mock function with given fields: ctx, owner, id
|
||||
func (_m *ConfigRepository) RetrieveByID(ctx context.Context, owner string, id string) (bootstrap.Config, error) {
|
||||
ret := _m.Called(ctx, owner, id)
|
||||
// RetrieveByID provides a mock function with given fields: ctx, domainID, id
|
||||
func (_m *ConfigRepository) RetrieveByID(ctx context.Context, domainID string, id string) (bootstrap.Config, error) {
|
||||
ret := _m.Called(ctx, domainID, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByID")
|
||||
@@ -212,16 +212,16 @@ func (_m *ConfigRepository) RetrieveByID(ctx context.Context, owner string, id s
|
||||
var r0 bootstrap.Config
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (bootstrap.Config, error)); ok {
|
||||
return rf(ctx, owner, id)
|
||||
return rf(ctx, domainID, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) bootstrap.Config); ok {
|
||||
r0 = rf(ctx, owner, id)
|
||||
r0 = rf(ctx, domainID, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.Config)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, owner, id)
|
||||
r1 = rf(ctx, domainID, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -275,9 +275,9 @@ func (_m *ConfigRepository) Update(ctx context.Context, cfg bootstrap.Config) er
|
||||
return r0
|
||||
}
|
||||
|
||||
// UpdateCert provides a mock function with given fields: ctx, owner, thingID, clientCert, clientKey, caCert
|
||||
func (_m *ConfigRepository) UpdateCert(ctx context.Context, owner string, thingID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
|
||||
ret := _m.Called(ctx, owner, thingID, clientCert, clientKey, caCert)
|
||||
// UpdateCert provides a mock function with given fields: ctx, domainID, thingID, clientCert, clientKey, caCert
|
||||
func (_m *ConfigRepository) UpdateCert(ctx context.Context, domainID string, thingID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
|
||||
ret := _m.Called(ctx, domainID, thingID, clientCert, clientKey, caCert)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateCert")
|
||||
@@ -286,16 +286,16 @@ func (_m *ConfigRepository) UpdateCert(ctx context.Context, owner string, thingI
|
||||
var r0 bootstrap.Config
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (bootstrap.Config, error)); ok {
|
||||
return rf(ctx, owner, thingID, clientCert, clientKey, caCert)
|
||||
return rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) bootstrap.Config); ok {
|
||||
r0 = rf(ctx, owner, thingID, clientCert, clientKey, caCert)
|
||||
r0 = rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.Config)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok {
|
||||
r1 = rf(ctx, owner, thingID, clientCert, clientKey, caCert)
|
||||
r1 = rf(ctx, domainID, thingID, clientCert, clientKey, caCert)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -321,9 +321,9 @@ func (_m *ConfigRepository) UpdateChannel(ctx context.Context, c bootstrap.Chann
|
||||
return r0
|
||||
}
|
||||
|
||||
// UpdateConnections provides a mock function with given fields: ctx, owner, id, channels, connections
|
||||
func (_m *ConfigRepository) UpdateConnections(ctx context.Context, owner string, id string, channels []bootstrap.Channel, connections []string) error {
|
||||
ret := _m.Called(ctx, owner, id, channels, connections)
|
||||
// UpdateConnections provides a mock function with given fields: ctx, domainID, id, channels, connections
|
||||
func (_m *ConfigRepository) UpdateConnections(ctx context.Context, domainID string, id string, channels []bootstrap.Channel, connections []string) error {
|
||||
ret := _m.Called(ctx, domainID, id, channels, connections)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateConnections")
|
||||
@@ -331,7 +331,7 @@ func (_m *ConfigRepository) UpdateConnections(ctx context.Context, owner string,
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, []bootstrap.Channel, []string) error); ok {
|
||||
r0 = rf(ctx, owner, id, channels, connections)
|
||||
r0 = rf(ctx, domainID, id, channels, connections)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
+107
-102
@@ -49,8 +49,8 @@ func NewConfigRepository(db postgres.Database, log *slog.Logger) bootstrap.Confi
|
||||
}
|
||||
|
||||
func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string) (thingID string, err error) {
|
||||
q := `INSERT INTO configs (magistrala_thing, owner, name, client_cert, client_key, ca_cert, magistrala_key, external_id, external_key, content, state)
|
||||
VALUES (:magistrala_thing, :owner, :name, :client_cert, :client_key, :ca_cert, :magistrala_key, :external_id, :external_key, :content, :state)`
|
||||
q := `INSERT INTO configs (magistrala_thing, domain_id, name, client_cert, client_key, ca_cert, magistrala_key, external_id, external_key, content, state)
|
||||
VALUES (:magistrala_thing, :domain_id, :name, :client_cert, :client_key, :ca_cert, :magistrala_key, :external_id, :external_key, :content, :state)`
|
||||
|
||||
tx, err := cr.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -60,7 +60,7 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = cr.rollback(err, tx)
|
||||
err = cr.rollback("Save method", err, tx)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -68,13 +68,13 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
|
||||
switch pgErr := err.(type) {
|
||||
case *pgconn.PgError:
|
||||
if pgErr.Code == pgerrcode.UniqueViolation {
|
||||
return "", repoerr.ErrConflict
|
||||
err = repoerr.ErrConflict
|
||||
}
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := insertChannels(ctx, cfg.Owner, cfg.Channels, tx); err != nil {
|
||||
if err := insertChannels(cfg.DomainID, cfg.Channels, tx); err != nil {
|
||||
return "", errors.Wrap(errSaveChannels, err)
|
||||
}
|
||||
|
||||
@@ -82,20 +82,21 @@ func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsCo
|
||||
return "", errors.Wrap(errSaveConnections, err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", err
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
return "", commitErr
|
||||
}
|
||||
|
||||
return cfg.ThingID, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (bootstrap.Config, error) {
|
||||
func (cr configRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Config, error) {
|
||||
q := `SELECT magistrala_thing, magistrala_key, external_id, external_key, name, content, state, client_cert, ca_cert
|
||||
FROM configs
|
||||
WHERE magistrala_thing = :magistrala_thing AND owner = :owner`
|
||||
WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ThingID: id,
|
||||
Owner: owner,
|
||||
ThingID: id,
|
||||
DomainID: domainID,
|
||||
}
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
@@ -116,8 +117,8 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
|
||||
|
||||
q = `SELECT magistrala_channel, name, metadata FROM channels ch
|
||||
INNER JOIN connections conn
|
||||
ON ch.magistrala_channel = conn.channel_id AND ch.owner = conn.config_owner
|
||||
WHERE conn.config_id = :magistrala_thing AND conn.config_owner = :owner`
|
||||
ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id
|
||||
WHERE conn.config_id = :magistrala_thing AND conn.domain_id = :domain_id`
|
||||
|
||||
rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
@@ -133,7 +134,7 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
|
||||
cr.log.Error(fmt.Sprintf("Failed to read connected thing due to %s", err))
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
dbch.Owner = nullString(dbcfg.Owner)
|
||||
dbch.DomainID = nullString(dbcfg.DomainID)
|
||||
|
||||
ch, err := toChannel(dbch)
|
||||
if err != nil {
|
||||
@@ -148,12 +149,12 @@ func (cr configRepository) RetrieveByID(ctx context.Context, owner, id string) (
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
|
||||
search, params := cr.retrieveAll(owner, filter)
|
||||
func (cr configRepository) RetrieveAll(ctx context.Context, domainID string, thingIDs []string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
|
||||
search, params := buildRetrieveQueryParams(domainID, thingIDs, filter)
|
||||
n := len(params)
|
||||
|
||||
q := `SELECT magistrala_thing, magistrala_key, external_id, external_key, name, content, state
|
||||
FROM configs %s ORDER BY magistrala_thing LIMIT $%d OFFSET $%d`
|
||||
FROM configs %s ORDER BY magistrala_thing LIMIT $%d OFFSET $%d`
|
||||
q = fmt.Sprintf(q, search, n+1, n+2)
|
||||
|
||||
rows, err := cr.db.QueryContext(ctx, q, append(params, limit, offset)...)
|
||||
@@ -167,7 +168,7 @@ func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter
|
||||
configs := []bootstrap.Config{}
|
||||
|
||||
for rows.Next() {
|
||||
c := bootstrap.Config{Owner: owner}
|
||||
c := bootstrap.Config{DomainID: domainID}
|
||||
if err := rows.Scan(&c.ThingID, &c.ThingKey, &c.ExternalID, &c.ExternalKey, &name, &content, &c.State); err != nil {
|
||||
cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err))
|
||||
return bootstrap.ConfigsPage{}
|
||||
@@ -195,7 +196,7 @@ func (cr configRepository) RetrieveAll(ctx context.Context, owner string, filter
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
|
||||
q := `SELECT magistrala_thing, magistrala_key, external_key, owner, name, client_cert, client_key, ca_cert, content, state
|
||||
q := `SELECT magistrala_thing, magistrala_key, external_key, domain_id, name, client_cert, client_key, ca_cert, content, state
|
||||
FROM configs
|
||||
WHERE external_id = :external_id`
|
||||
dbcfg := dbConfig{
|
||||
@@ -219,9 +220,9 @@ func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID
|
||||
}
|
||||
|
||||
q = `SELECT magistrala_channel, name, metadata FROM channels ch
|
||||
INNER JOIN connections conn
|
||||
ON ch.magistrala_channel = conn.channel_id AND ch.owner = conn.config_owner
|
||||
WHERE conn.config_id = :magistrala_thing AND conn.config_owner = :owner`
|
||||
INNER JOIN connections conn
|
||||
ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id
|
||||
WHERE conn.config_id = :magistrala_thing AND conn.domain_id = :domain_id`
|
||||
|
||||
rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
@@ -254,13 +255,13 @@ func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID
|
||||
}
|
||||
|
||||
func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
|
||||
q := `UPDATE configs SET name = :name, content = :content WHERE magistrala_thing = :magistrala_thing AND owner = :owner `
|
||||
q := `UPDATE configs SET name = :name, content = :content WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id `
|
||||
|
||||
dbcfg := dbConfig{
|
||||
Name: nullString(cfg.Name),
|
||||
Content: nullString(cfg.Content),
|
||||
ThingID: cfg.ThingID,
|
||||
Owner: cfg.Owner,
|
||||
Name: nullString(cfg.Name),
|
||||
Content: nullString(cfg.Content),
|
||||
ThingID: cfg.ThingID,
|
||||
DomainID: cfg.DomainID,
|
||||
}
|
||||
|
||||
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
|
||||
@@ -280,14 +281,14 @@ func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) UpdateCert(ctx context.Context, owner, thingID, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND owner = :owner
|
||||
func (cr configRepository) UpdateCert(ctx context.Context, domainID, thingID, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id
|
||||
RETURNING magistrala_thing, client_cert, client_key, ca_cert`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ThingID: thingID,
|
||||
ClientCert: nullString(clientCert),
|
||||
Owner: owner,
|
||||
DomainID: domainID,
|
||||
ClientKey: nullString(clientKey),
|
||||
CaCert: nullString(caCert),
|
||||
}
|
||||
@@ -309,7 +310,7 @@ func (cr configRepository) UpdateCert(ctx context.Context, owner, thingID, clien
|
||||
return toConfig(dbcfg), nil
|
||||
}
|
||||
|
||||
func (cr configRepository) UpdateConnections(ctx context.Context, owner, id string, channels []bootstrap.Channel, connections []string) error {
|
||||
func (cr configRepository) UpdateConnections(ctx context.Context, domainID, id string, channels []bootstrap.Channel, connections []string) (err error) {
|
||||
tx, err := cr.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
@@ -317,35 +318,37 @@ func (cr configRepository) UpdateConnections(ctx context.Context, owner, id stri
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = cr.rollback(err, tx)
|
||||
err = cr.rollback("UpdateConnections method", err, tx)
|
||||
} else {
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
err = commitErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := insertChannels(ctx, owner, channels, tx); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
if err = insertChannels(domainID, channels, tx); err != nil {
|
||||
err = errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updateConnections(ctx, owner, id, connections, tx); err != nil {
|
||||
if err = updateConnections(domainID, id, connections, tx); err != nil {
|
||||
if e, ok := err.(*pgconn.PgError); ok {
|
||||
if e.Code == pgerrcode.ForeignKeyViolation {
|
||||
return repoerr.ErrNotFound
|
||||
err = repoerr.ErrNotFound
|
||||
}
|
||||
}
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
err = errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) Remove(ctx context.Context, owner, id string) error {
|
||||
q := `DELETE FROM configs WHERE magistrala_thing = :magistrala_thing AND owner = :owner`
|
||||
func (cr configRepository) Remove(ctx context.Context, domainID, id string) error {
|
||||
q := `DELETE FROM configs WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id`
|
||||
dbcfg := dbConfig{
|
||||
ThingID: id,
|
||||
Owner: owner,
|
||||
ThingID: id,
|
||||
DomainID: domainID,
|
||||
}
|
||||
|
||||
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
|
||||
@@ -359,13 +362,13 @@ func (cr configRepository) Remove(ctx context.Context, owner, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) ChangeState(ctx context.Context, owner, id string, state bootstrap.State) error {
|
||||
q := `UPDATE configs SET state = :state WHERE magistrala_thing = :magistrala_thing AND owner = :owner;`
|
||||
func (cr configRepository) ChangeState(ctx context.Context, domainID, id string, state bootstrap.State) error {
|
||||
q := `UPDATE configs SET state = :state WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id;`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ThingID: id,
|
||||
State: state,
|
||||
Owner: owner,
|
||||
ThingID: id,
|
||||
State: state,
|
||||
DomainID: domainID,
|
||||
}
|
||||
|
||||
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
|
||||
@@ -385,7 +388,7 @@ func (cr configRepository) ChangeState(ctx context.Context, owner, id string, st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) ListExisting(ctx context.Context, owner string, ids []string) ([]bootstrap.Channel, error) {
|
||||
func (cr configRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) {
|
||||
var channels []bootstrap.Channel
|
||||
if len(ids) == 0 {
|
||||
return channels, nil
|
||||
@@ -396,8 +399,8 @@ func (cr configRepository) ListExisting(ctx context.Context, owner string, ids [
|
||||
return []bootstrap.Channel{}, err
|
||||
}
|
||||
|
||||
q := "SELECT magistrala_channel, name, metadata FROM channels WHERE owner = $1 AND magistrala_channel = ANY ($2)"
|
||||
rows, err := cr.db.QueryxContext(ctx, q, owner, chans)
|
||||
q := "SELECT magistrala_channel, name, metadata FROM channels WHERE domain_id = $1 AND magistrala_channel = ANY ($2)"
|
||||
rows, err := cr.db.QueryxContext(ctx, q, domainID, chans)
|
||||
if err != nil {
|
||||
return []bootstrap.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
@@ -472,62 +475,66 @@ func (cr configRepository) ConnectThing(ctx context.Context, channelID, thingID
|
||||
func (cr configRepository) DisconnectThing(ctx context.Context, channelID, thingID string) error {
|
||||
q := `UPDATE configs SET state = $1 WHERE EXISTS (
|
||||
SELECT 1 FROM connections WHERE config_id = $2 AND channel_id = $3)`
|
||||
result, err := cr.db.ExecContext(ctx, q, bootstrap.Inactive, thingID, channelID)
|
||||
_, err := cr.db.ExecContext(ctx, q, bootstrap.Inactive, thingID, channelID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errDisconnectThing, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) retrieveAll(owner string, filter bootstrap.Filter) (string, []interface{}) {
|
||||
template := `WHERE owner = $1 %s`
|
||||
params := []interface{}{owner}
|
||||
// One empty string so that strings Join works if only one filter is applied.
|
||||
queries := []string{""}
|
||||
// Since owner is the first param, start from 2.
|
||||
counter := 2
|
||||
func buildRetrieveQueryParams(domainID string, thingIDs []string, filter bootstrap.Filter) (string, []interface{}) {
|
||||
params := []interface{}{}
|
||||
queries := []string{}
|
||||
|
||||
if len(thingIDs) != 0 {
|
||||
queries = append(queries, fmt.Sprintf("magistrala_thing IN ('%s')", strings.Join(thingIDs, "','")))
|
||||
} else if domainID != "" {
|
||||
params = append(params, domainID)
|
||||
queries = append(queries, fmt.Sprintf("domain_id = $%d", len(params)))
|
||||
}
|
||||
|
||||
// Adjust the starting point for placeholders based on the current length of params
|
||||
counter := len(params) + 1
|
||||
for k, v := range filter.FullMatch {
|
||||
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
|
||||
params = append(params, v)
|
||||
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
|
||||
counter++
|
||||
}
|
||||
for k, v := range filter.PartialMatch {
|
||||
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
|
||||
params = append(params, v)
|
||||
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
|
||||
counter++
|
||||
}
|
||||
|
||||
f := strings.Join(queries, " AND ")
|
||||
|
||||
return fmt.Sprintf(template, f), params
|
||||
if len(queries) > 0 {
|
||||
return "WHERE " + strings.Join(queries, " AND "), params
|
||||
}
|
||||
return "", params
|
||||
}
|
||||
|
||||
func (cr configRepository) rollback(defErr error, tx *sqlx.Tx) error {
|
||||
func (cr configRepository) rollback(content string, defErr error, tx *sqlx.Tx) error {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
return errors.Wrap(defErr, errors.Wrap(errors.New("failed to rollback"), err))
|
||||
return errors.Wrap(defErr, errors.Wrap(errors.New("failed to rollback at "+content), err))
|
||||
}
|
||||
|
||||
return defErr
|
||||
}
|
||||
|
||||
func insertChannels(_ context.Context, owner string, channels []bootstrap.Channel, tx *sqlx.Tx) error {
|
||||
func insertChannels(domainID string, channels []bootstrap.Channel, tx *sqlx.Tx) error {
|
||||
if len(channels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var chans []dbChannel
|
||||
for _, ch := range channels {
|
||||
dbch, err := toDBChannel(owner, ch)
|
||||
dbch, err := toDBChannel(domainID, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chans = append(chans, dbch)
|
||||
}
|
||||
q := `INSERT INTO channels (magistrala_channel, owner, name, metadata, parent_id, description, created_at, updated_at, updated_by, status)
|
||||
VALUES (:magistrala_channel, :owner, :name, :metadata, :parent_id, :description, :created_at, :updated_at, :updated_by, :status)`
|
||||
q := `INSERT INTO channels (magistrala_channel, domain_id, name, metadata, parent_id, description, created_at, updated_at, updated_by, status)
|
||||
VALUES (:magistrala_channel, :domain_id, :name, :metadata, :parent_id, :description, :created_at, :updated_at, :updated_by, :status)`
|
||||
if _, err := tx.NamedExec(q, chans); err != nil {
|
||||
e := err
|
||||
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgerrcode.UniqueViolation {
|
||||
@@ -544,15 +551,15 @@ func insertConnections(_ context.Context, cfg bootstrap.Config, connections []st
|
||||
return nil
|
||||
}
|
||||
|
||||
q := `INSERT INTO connections (config_id, channel_id, config_owner, channel_owner)
|
||||
VALUES (:config_id, :channel_id, :config_owner, :channel_owner)`
|
||||
q := `INSERT INTO connections (config_id, channel_id, domain_id)
|
||||
VALUES (:config_id, :channel_id, :domain_id)`
|
||||
|
||||
conns := []dbConnection{}
|
||||
for _, conn := range connections {
|
||||
dbconn := dbConnection{
|
||||
Config: cfg.ThingID,
|
||||
Channel: conn,
|
||||
ConfigOwner: cfg.Owner,
|
||||
ChannelOwner: cfg.Owner,
|
||||
Config: cfg.ThingID,
|
||||
Channel: conn,
|
||||
DomainID: cfg.DomainID,
|
||||
}
|
||||
conns = append(conns, dbconn)
|
||||
}
|
||||
@@ -561,13 +568,13 @@ func insertConnections(_ context.Context, cfg bootstrap.Config, connections []st
|
||||
return err
|
||||
}
|
||||
|
||||
func updateConnections(_ context.Context, owner, id string, connections []string, tx *sqlx.Tx) error {
|
||||
func updateConnections(domainID, id string, connections []string, tx *sqlx.Tx) error {
|
||||
if len(connections) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
q := `DELETE FROM connections
|
||||
WHERE config_id = $1 AND config_owner = $2 AND channel_owner = $2
|
||||
WHERE config_id = $1 AND domain_id = $2
|
||||
AND channel_id NOT IN ($3)`
|
||||
|
||||
var conn pgtype.TextArray
|
||||
@@ -575,7 +582,7 @@ func updateConnections(_ context.Context, owner, id string, connections []string
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := tx.Exec(q, id, owner, conn)
|
||||
res, err := tx.Exec(q, id, domainID, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -585,16 +592,15 @@ func updateConnections(_ context.Context, owner, id string, connections []string
|
||||
return err
|
||||
}
|
||||
|
||||
q = `INSERT INTO connections (config_id, channel_id, config_owner, channel_owner)
|
||||
VALUES (:config_id, :channel_id, :config_owner, :channel_owner)`
|
||||
q = `INSERT INTO connections (config_id, channel_id, domain_id)
|
||||
VALUES (:config_id, :channel_id, :domain_id)`
|
||||
|
||||
conns := []dbConnection{}
|
||||
for _, conn := range connections {
|
||||
dbconn := dbConnection{
|
||||
Config: id,
|
||||
Channel: conn,
|
||||
ConfigOwner: owner,
|
||||
ChannelOwner: owner,
|
||||
Config: id,
|
||||
Channel: conn,
|
||||
DomainID: domainID,
|
||||
}
|
||||
conns = append(conns, dbconn)
|
||||
}
|
||||
@@ -636,7 +642,7 @@ func nullTime(t time.Time) sql.NullTime {
|
||||
|
||||
type dbConfig struct {
|
||||
ThingID string `db:"magistrala_thing"`
|
||||
Owner string `db:"owner"`
|
||||
DomainID string `db:"domain_id"`
|
||||
Name sql.NullString `db:"name"`
|
||||
ClientCert sql.NullString `db:"client_cert"`
|
||||
ClientKey sql.NullString `db:"client_key"`
|
||||
@@ -651,7 +657,7 @@ type dbConfig struct {
|
||||
func toDBConfig(cfg bootstrap.Config) dbConfig {
|
||||
return dbConfig{
|
||||
ThingID: cfg.ThingID,
|
||||
Owner: cfg.Owner,
|
||||
DomainID: cfg.DomainID,
|
||||
Name: nullString(cfg.Name),
|
||||
ClientCert: nullString(cfg.ClientCert),
|
||||
ClientKey: nullString(cfg.ClientKey),
|
||||
@@ -667,7 +673,7 @@ func toDBConfig(cfg bootstrap.Config) dbConfig {
|
||||
func toConfig(dbcfg dbConfig) bootstrap.Config {
|
||||
cfg := bootstrap.Config{
|
||||
ThingID: dbcfg.ThingID,
|
||||
Owner: dbcfg.Owner,
|
||||
DomainID: dbcfg.DomainID,
|
||||
ThingKey: dbcfg.ThingKey,
|
||||
ExternalID: dbcfg.ExternalID,
|
||||
ExternalKey: dbcfg.ExternalKey,
|
||||
@@ -699,7 +705,7 @@ func toConfig(dbcfg dbConfig) bootstrap.Config {
|
||||
type dbChannel struct {
|
||||
ID string `db:"magistrala_channel"`
|
||||
Name sql.NullString `db:"name"`
|
||||
Owner sql.NullString `db:"owner"`
|
||||
DomainID sql.NullString `db:"domain_id"`
|
||||
Metadata string `db:"metadata"`
|
||||
Parent sql.NullString `db:"parent_id,omitempty"`
|
||||
Description string `db:"description,omitempty"`
|
||||
@@ -709,11 +715,11 @@ type dbChannel struct {
|
||||
Status clients.Status `db:"status"`
|
||||
}
|
||||
|
||||
func toDBChannel(owner string, ch bootstrap.Channel) (dbChannel, error) {
|
||||
func toDBChannel(domainID string, ch bootstrap.Channel) (dbChannel, error) {
|
||||
dbch := dbChannel{
|
||||
ID: ch.ID,
|
||||
Name: nullString(ch.Name),
|
||||
Owner: nullString(owner),
|
||||
DomainID: nullString(domainID),
|
||||
Parent: nullString(ch.Parent),
|
||||
Description: ch.Description,
|
||||
CreatedAt: ch.CreatedAt,
|
||||
@@ -742,8 +748,8 @@ func toChannel(dbch dbChannel) (bootstrap.Channel, error) {
|
||||
if dbch.Name.Valid {
|
||||
ch.Name = dbch.Name.String
|
||||
}
|
||||
if dbch.Owner.Valid {
|
||||
ch.Owner = dbch.Owner.String
|
||||
if dbch.DomainID.Valid {
|
||||
ch.DomainID = dbch.DomainID.String
|
||||
}
|
||||
if dbch.Parent.Valid {
|
||||
ch.Parent = dbch.Parent.String
|
||||
@@ -763,8 +769,7 @@ func toChannel(dbch dbChannel) (bootstrap.Channel, error) {
|
||||
}
|
||||
|
||||
type dbConnection struct {
|
||||
Config string `db:"config_id"`
|
||||
Channel string `db:"channel_id"`
|
||||
ConfigOwner string `db:"config_owner"`
|
||||
ChannelOwner string `db:"channel_owner"`
|
||||
Config string `db:"config_id"`
|
||||
Channel string `db:"channel_id"`
|
||||
DomainID string `db:"domain_id"`
|
||||
}
|
||||
|
||||
+175
-142
@@ -27,7 +27,7 @@ var (
|
||||
ThingKey: "mg-key",
|
||||
ExternalID: "external-id",
|
||||
ExternalKey: "external-key",
|
||||
Owner: "user@email.com",
|
||||
DomainID: testsutil.GenerateUUID(&testing.T{}),
|
||||
Channels: []bootstrap.Channel{
|
||||
{ID: "1", Name: "name 1", Metadata: map[string]interface{}{"meta": 1.0}},
|
||||
{ID: "2", Name: "name 2", Metadata: map[string]interface{}{"meta": 2.0}},
|
||||
@@ -121,38 +121,38 @@ func TestRetrieveByID(t *testing.T) {
|
||||
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
id string
|
||||
err error
|
||||
desc string
|
||||
domainID string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve config",
|
||||
owner: c.Owner,
|
||||
id: id,
|
||||
err: nil,
|
||||
desc: "retrieve config",
|
||||
domainID: c.DomainID,
|
||||
id: id,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve config with wrong owner",
|
||||
owner: "2",
|
||||
id: id,
|
||||
err: repoerr.ErrNotFound,
|
||||
desc: "retrieve config with wrong domain ID ",
|
||||
domainID: "2",
|
||||
id: id,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a non-existing config",
|
||||
owner: c.Owner,
|
||||
id: nonexistentConfID.String(),
|
||||
err: repoerr.ErrNotFound,
|
||||
desc: "retrieve a non-existing config",
|
||||
domainID: c.DomainID,
|
||||
id: nonexistentConfID.String(),
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a config with invalid ID",
|
||||
owner: c.Owner,
|
||||
id: "invalid",
|
||||
err: repoerr.ErrNotFound,
|
||||
desc: "retrieve a config with invalid ID",
|
||||
domainID: c.DomainID,
|
||||
id: "invalid",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
_, err := repo.RetrieveByID(context.Background(), tc.owner, tc.id)
|
||||
_, err := repo.RetrieveByID(context.Background(), tc.domainID, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
@@ -162,6 +162,8 @@ func TestRetrieveAll(t *testing.T) {
|
||||
err := deleteChannels(context.Background(), repo)
|
||||
require.Nil(t, err, "Channels cleanup expected to succeed.")
|
||||
|
||||
thingIDs := make([]string, numConfigs)
|
||||
|
||||
for i := 0; i < numConfigs; i++ {
|
||||
c := config
|
||||
|
||||
@@ -173,6 +175,8 @@ func TestRetrieveAll(t *testing.T) {
|
||||
c.ThingID = uid.String()
|
||||
c.ThingKey = uid.String()
|
||||
|
||||
thingIDs[i] = c.ThingID
|
||||
|
||||
if i%2 == 0 {
|
||||
c.State = bootstrap.Active
|
||||
}
|
||||
@@ -184,55 +188,85 @@ func TestRetrieveAll(t *testing.T) {
|
||||
_, err = repo.Save(context.Background(), c, channels)
|
||||
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
offset uint64
|
||||
limit uint64
|
||||
filter bootstrap.Filter
|
||||
size int
|
||||
desc string
|
||||
domainID string
|
||||
thingID []string
|
||||
offset uint64
|
||||
limit uint64
|
||||
filter bootstrap.Filter
|
||||
size int
|
||||
}{
|
||||
{
|
||||
desc: "retrieve all",
|
||||
owner: config.Owner,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: numConfigs,
|
||||
desc: "retrieve all configs",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: numConfigs,
|
||||
},
|
||||
{
|
||||
desc: "retrieve subset",
|
||||
owner: config.Owner,
|
||||
offset: 5,
|
||||
limit: uint64(numConfigs - 5),
|
||||
size: numConfigs - 5,
|
||||
desc: "retrieve a subset of configs",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{},
|
||||
offset: 5,
|
||||
limit: uint64(numConfigs - 5),
|
||||
size: numConfigs - 5,
|
||||
},
|
||||
{
|
||||
desc: "retrieve wrong owner",
|
||||
owner: "2",
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: 0,
|
||||
desc: "retrieve with wrong domain ID ",
|
||||
domainID: "2",
|
||||
thingID: []string{},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: 0,
|
||||
},
|
||||
{
|
||||
desc: "retrieve all active",
|
||||
owner: config.Owner,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}},
|
||||
size: numConfigs / 2,
|
||||
desc: "retrieve all active configs ",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}},
|
||||
size: numConfigs / 2,
|
||||
},
|
||||
{
|
||||
desc: "retrieve search by name",
|
||||
owner: config.Owner,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
|
||||
size: 1,
|
||||
desc: "retrieve all with partial match filter",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
|
||||
size: 1,
|
||||
},
|
||||
{
|
||||
desc: "retrieve search by name",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
|
||||
size: 1,
|
||||
},
|
||||
{
|
||||
desc: "retrieve by valid thingIDs",
|
||||
domainID: config.DomainID,
|
||||
thingID: thingIDs,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: 10,
|
||||
},
|
||||
{
|
||||
desc: "retrieve by non-existing thingID",
|
||||
domainID: config.DomainID,
|
||||
thingID: []string{"non-existing"},
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: 0,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
ret := repo.RetrieveAll(context.Background(), tc.owner, tc.filter, tc.offset, tc.limit)
|
||||
ret := repo.RetrieveAll(context.Background(), tc.domainID, tc.thingID, tc.filter, tc.offset, tc.limit)
|
||||
size := len(ret.Configs)
|
||||
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
|
||||
}
|
||||
@@ -295,8 +329,8 @@ func TestUpdate(t *testing.T) {
|
||||
c.Content = "new content"
|
||||
c.Name = "new name"
|
||||
|
||||
wrongOwner := c
|
||||
wrongOwner.Owner = "3"
|
||||
wrongDomainID := c
|
||||
wrongDomainID.DomainID = "3"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -305,8 +339,8 @@ func TestUpdate(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update with wrong owner",
|
||||
config: wrongOwner,
|
||||
desc: "update with wrong domainID ",
|
||||
config: wrongDomainID,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
@@ -340,13 +374,13 @@ func TestUpdateCert(t *testing.T) {
|
||||
c.Content = "new content"
|
||||
c.Name = "new name"
|
||||
|
||||
wrongOwner := c
|
||||
wrongOwner.Owner = "3"
|
||||
wrongDomainID := c
|
||||
wrongDomainID.DomainID = "3"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
thingID string
|
||||
owner string
|
||||
domainID string
|
||||
cert string
|
||||
certKey string
|
||||
ca string
|
||||
@@ -354,34 +388,34 @@ func TestUpdateCert(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update with wrong owner",
|
||||
desc: "update with wrong domain ID ",
|
||||
thingID: "",
|
||||
cert: "cert",
|
||||
certKey: "certKey",
|
||||
ca: "",
|
||||
owner: "wrong",
|
||||
domainID: wrongDomainID.DomainID,
|
||||
expectedConfig: bootstrap.Config{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update a config",
|
||||
thingID: c.ThingID,
|
||||
cert: "cert",
|
||||
certKey: "certKey",
|
||||
ca: "ca",
|
||||
owner: c.Owner,
|
||||
desc: "update a config",
|
||||
thingID: c.ThingID,
|
||||
cert: "cert",
|
||||
certKey: "certKey",
|
||||
ca: "ca",
|
||||
domainID: c.DomainID,
|
||||
expectedConfig: bootstrap.Config{
|
||||
ThingID: c.ThingID,
|
||||
ClientCert: "cert",
|
||||
CACert: "ca",
|
||||
ClientKey: "certKey",
|
||||
Owner: c.Owner,
|
||||
DomainID: c.DomainID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
cfg, err := repo.UpdateCert(context.Background(), tc.owner, tc.thingID, tc.cert, tc.certKey, tc.ca)
|
||||
cfg, err := repo.UpdateCert(context.Background(), tc.domainID, tc.thingID, tc.cert, tc.certKey, tc.ca)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg))
|
||||
}
|
||||
@@ -415,7 +449,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
domainID string
|
||||
id string
|
||||
channels []bootstrap.Channel
|
||||
connections []string
|
||||
@@ -423,7 +457,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "update connections of non-existing config",
|
||||
owner: config.Owner,
|
||||
domainID: config.DomainID,
|
||||
id: "unknown",
|
||||
channels: nil,
|
||||
connections: []string{channels[1]},
|
||||
@@ -431,7 +465,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update connections",
|
||||
owner: config.Owner,
|
||||
domainID: config.DomainID,
|
||||
id: c.ThingID,
|
||||
channels: nil,
|
||||
connections: []string{channels[1]},
|
||||
@@ -439,7 +473,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update connections with existing channels",
|
||||
owner: config.Owner,
|
||||
domainID: config.DomainID,
|
||||
id: c2,
|
||||
channels: nil,
|
||||
connections: channels,
|
||||
@@ -447,7 +481,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update connections no channels",
|
||||
owner: config.Owner,
|
||||
domainID: config.DomainID,
|
||||
id: c.ThingID,
|
||||
channels: nil,
|
||||
connections: nil,
|
||||
@@ -455,7 +489,7 @@ func TestUpdateConnections(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := repo.UpdateConnections(context.Background(), tc.owner, tc.id, tc.channels, tc.connections)
|
||||
err := repo.UpdateConnections(context.Background(), tc.domainID, tc.id, tc.channels, tc.connections)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
@@ -479,10 +513,10 @@ func TestRemove(t *testing.T) {
|
||||
// Removal works the same for both existing and non-existing
|
||||
// (removed) config
|
||||
for i := 0; i < 2; i++ {
|
||||
err := repo.Remove(context.Background(), c.Owner, id)
|
||||
err := repo.Remove(context.Background(), c.DomainID, id)
|
||||
assert.Nil(t, err, fmt.Sprintf("%d: failed to remove config due to: %s", i, err))
|
||||
|
||||
_, err = repo.RetrieveByID(context.Background(), c.Owner, id)
|
||||
_, err = repo.RetrieveByID(context.Background(), c.DomainID, id)
|
||||
assert.True(t, errors.Contains(err, repoerr.ErrNotFound), fmt.Sprintf("%d: expected %s got %s", i, repoerr.ErrNotFound, err))
|
||||
}
|
||||
}
|
||||
@@ -504,41 +538,41 @@ func TestChangeState(t *testing.T) {
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
id string
|
||||
state bootstrap.State
|
||||
err error
|
||||
desc string
|
||||
domainID string
|
||||
id string
|
||||
state bootstrap.State
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "change state with wrong owner",
|
||||
id: saved,
|
||||
owner: "2",
|
||||
err: repoerr.ErrNotFound,
|
||||
desc: "change state with wrong domain ID ",
|
||||
id: saved,
|
||||
domainID: "2",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "change state with wrong id",
|
||||
id: "wrong",
|
||||
owner: c.Owner,
|
||||
err: repoerr.ErrNotFound,
|
||||
desc: "change state with wrong id",
|
||||
id: "wrong",
|
||||
domainID: c.DomainID,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "change state to Active",
|
||||
id: saved,
|
||||
owner: c.Owner,
|
||||
state: bootstrap.Active,
|
||||
err: nil,
|
||||
desc: "change state to Active",
|
||||
id: saved,
|
||||
domainID: c.DomainID,
|
||||
state: bootstrap.Active,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "change state to Inactive",
|
||||
id: saved,
|
||||
owner: c.Owner,
|
||||
state: bootstrap.Inactive,
|
||||
err: nil,
|
||||
desc: "change state to Inactive",
|
||||
id: saved,
|
||||
domainID: c.DomainID,
|
||||
state: bootstrap.Inactive,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := repo.ChangeState(context.Background(), tc.owner, tc.id, tc.state)
|
||||
err := repo.ChangeState(context.Background(), tc.domainID, tc.id, tc.state)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
@@ -564,31 +598,31 @@ func TestListExisting(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
domainID string
|
||||
connections []string
|
||||
existing []bootstrap.Channel
|
||||
}{
|
||||
{
|
||||
desc: "list all existing channels",
|
||||
owner: c.Owner,
|
||||
domainID: c.DomainID,
|
||||
connections: channels,
|
||||
existing: chs,
|
||||
},
|
||||
{
|
||||
desc: "list a subset of existing channels",
|
||||
owner: c.Owner,
|
||||
domainID: c.DomainID,
|
||||
connections: []string{channels[0], "5"},
|
||||
existing: []bootstrap.Channel{chs[0]},
|
||||
},
|
||||
{
|
||||
desc: "list a subset of existing channels empty",
|
||||
owner: c.Owner,
|
||||
domainID: c.DomainID,
|
||||
connections: []string{"5", "6"},
|
||||
existing: []bootstrap.Channel{},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
existing, err := repo.ListExisting(context.Background(), tc.owner, tc.connections)
|
||||
existing, err := repo.ListExisting(context.Background(), tc.domainID, tc.connections)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error: %s", tc.desc, err))
|
||||
assert.ElementsMatch(t, tc.existing, existing, fmt.Sprintf("%s: Got non-matching elements.", tc.desc))
|
||||
}
|
||||
@@ -640,7 +674,7 @@ func TestUpdateChannel(t *testing.T) {
|
||||
err = repo.UpdateChannel(context.Background(), update)
|
||||
assert.Nil(t, err, fmt.Sprintf("updating config expected to succeed: %s.\n", err))
|
||||
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
var retreved bootstrap.Channel
|
||||
for _, c := range cfg.Channels {
|
||||
@@ -649,7 +683,7 @@ func TestUpdateChannel(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
update.Owner = retreved.Owner
|
||||
update.DomainID = retreved.DomainID
|
||||
assert.Equal(t, update, retreved, fmt.Sprintf("expected %s, go %s", update, retreved))
|
||||
}
|
||||
|
||||
@@ -671,7 +705,7 @@ func TestRemoveChannel(t *testing.T) {
|
||||
err = repo.RemoveChannel(context.Background(), c.Channels[0].ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.NotContains(t, cfg.Channels, c.Channels[0], fmt.Sprintf("expected to remove channel %s from %s", c.Channels[0], cfg.Channels))
|
||||
}
|
||||
@@ -704,38 +738,37 @@ func TestConnectThing(t *testing.T) {
|
||||
|
||||
emptyThing := c
|
||||
emptyThing.ThingID = ""
|
||||
emptyThing.ThingKey = ""
|
||||
emptyThing.ExternalID = ""
|
||||
emptyThing.ExternalKey = ""
|
||||
emptyThing.Channels = []bootstrap.Channel{}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
domainID string
|
||||
id string
|
||||
state bootstrap.State
|
||||
channels []bootstrap.Channel
|
||||
connections []string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "connect disconnected thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: saved,
|
||||
state: bootstrap.Inactive,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "connect already connected thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: connectedThing.ThingID,
|
||||
state: connectedThing.State,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "connect non-existent thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: wrongID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
@@ -743,7 +776,7 @@ func TestConnectThing(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "connect random thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: randomThing.ThingID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
@@ -751,7 +784,7 @@ func TestConnectThing(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "connect empty thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: emptyThing.ThingID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
@@ -763,7 +796,7 @@ func TestConnectThing(t *testing.T) {
|
||||
if i == 0 {
|
||||
err = repo.ConnectThing(context.Background(), ch.ID, tc.id)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err))
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg))
|
||||
} else {
|
||||
@@ -771,7 +804,7 @@ func TestConnectThing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg))
|
||||
}
|
||||
@@ -805,57 +838,57 @@ func TestDisconnectThing(t *testing.T) {
|
||||
|
||||
emptyThing := c
|
||||
emptyThing.ThingID = ""
|
||||
emptyThing.ThingKey = ""
|
||||
emptyThing.ExternalID = ""
|
||||
emptyThing.ExternalKey = ""
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
owner string
|
||||
domainID string
|
||||
id string
|
||||
state bootstrap.State
|
||||
channels []bootstrap.Channel
|
||||
connections []string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "disconnect connected thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: connectedThing.ThingID,
|
||||
state: connectedThing.State,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disconnect already disconnected thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: saved,
|
||||
state: bootstrap.Inactive,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disconnect invalid thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: wrongID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: repoerr.ErrNotFound,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disconnect random thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: randomThing.ThingID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: repoerr.ErrNotFound,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disconnect empty thing",
|
||||
owner: config.Owner,
|
||||
domainID: c.DomainID,
|
||||
id: emptyThing.ThingID,
|
||||
channels: c.Channels,
|
||||
connections: channels,
|
||||
err: repoerr.ErrNotFound,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -864,11 +897,11 @@ func TestDisconnectThing(t *testing.T) {
|
||||
err = repo.DisconnectThing(context.Background(), ch.ID, tc.id)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.Owner, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, cfg.State, bootstrap.Inactive, fmt.Sprintf("expected to be inactive when a connection is removed from %s", cfg))
|
||||
cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ThingID)
|
||||
assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, cfg.State, bootstrap.Inactive, fmt.Sprintf("expected to be inactive when a connection is removed from %s", cfg))
|
||||
}
|
||||
}
|
||||
|
||||
func deleteChannels(ctx context.Context, repo bootstrap.ConfigRepository) error {
|
||||
|
||||
@@ -83,6 +83,26 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
`ALTER TABLE IF EXISTS channels RENAME COLUMN mainflux_channel TO magistrala_channel`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_5",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN owner TO domain_id`,
|
||||
`ALTER TABLE IF EXISTS channels RENAME COLUMN owner TO domain_id`,
|
||||
`ALTER TABLE IF EXISTS configs ADD CONSTRAINT configs_name_domain_id_key UNIQUE (name, domain_id)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_6",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS connections DROP CONSTRAINT IF EXISTS connections_pkey`,
|
||||
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS channel_owner`,
|
||||
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS config_owner`,
|
||||
`ALTER TABLE IF EXISTS connections ADD COLUMN IF NOT EXISTS domain_id VARCHAR(256) NOT NULL`,
|
||||
`ALTER TABLE IF EXISTS connections ADD CONSTRAINT connections_pkey PRIMARY KEY (channel_id, config_id, domain_id)`,
|
||||
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (channel_id, domain_id) REFERENCES channels (magistrala_channel, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
|
||||
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (config_id, domain_id) REFERENCES configs (magistrala_thing, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+133
-25
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/auth"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
@@ -34,6 +35,9 @@ var (
|
||||
// ErrAddBootstrap indicates error in adding bootstrap configuration.
|
||||
ErrAddBootstrap = errors.New("failed to add bootstrap configuration")
|
||||
|
||||
// ErrNotInSameDomain indicates entities are not in the same domain.
|
||||
errNotInSameDomain = errors.New("entities are not in the same domain")
|
||||
|
||||
errUpdateConnections = errors.New("failed to update connections")
|
||||
errRemoveBootstrap = errors.New("failed to remove bootstrap configuration")
|
||||
errChangeState = errors.New("failed to change state of bootstrap configuration")
|
||||
@@ -82,7 +86,7 @@ type Service interface {
|
||||
// Bootstrap returns Config to the Thing with provided external ID using external key.
|
||||
Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error)
|
||||
|
||||
// ChangeState changes state of the Thing with given ID and owner.
|
||||
// ChangeState changes state of the Thing with given thing ID and domain ID.
|
||||
ChangeState(ctx context.Context, token, id string, state State) error
|
||||
|
||||
// Methods RemoveConfig, UpdateChannel, and RemoveChannel are used as
|
||||
@@ -123,26 +127,29 @@ type bootstrapService struct {
|
||||
}
|
||||
|
||||
// New returns new Bootstrap service.
|
||||
func New(auth magistrala.AuthServiceClient, configs ConfigRepository, sdk mgsdk.SDK, encKey []byte, idp magistrala.IDProvider) Service {
|
||||
func New(uauth magistrala.AuthServiceClient, configs ConfigRepository, sdk mgsdk.SDK, encKey []byte, idp magistrala.IDProvider) Service {
|
||||
return &bootstrapService{
|
||||
configs: configs,
|
||||
sdk: sdk,
|
||||
auth: auth,
|
||||
auth: uauth,
|
||||
encKey: encKey,
|
||||
idProvider: idp,
|
||||
}
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (Config, error) {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if _, err := bs.authorize(ctx, "", auth.UsersKind, user.GetId(), auth.MembershipPermission, auth.DomainType, user.GetDomainId()); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
toConnect := bs.toIDList(cfg.Channels)
|
||||
|
||||
// Check if channels exist. This is the way to prevent fetching channels that already exist.
|
||||
existing, err := bs.configs.ListExisting(ctx, owner, toConnect)
|
||||
existing, err := bs.configs.ListExisting(ctx, user.GetDomainId(), toConnect)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errCheckChannels, err)
|
||||
}
|
||||
@@ -158,8 +165,14 @@ func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (C
|
||||
return Config{}, errors.Wrap(errThingNotFound, err)
|
||||
}
|
||||
|
||||
for _, channel := range cfg.Channels {
|
||||
if channel.DomainID != mgThing.DomainID {
|
||||
return Config{}, errors.Wrap(svcerr.ErrMalformedEntity, errNotInSameDomain)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.ThingID = mgThing.ID
|
||||
cfg.Owner = owner
|
||||
cfg.DomainID = user.GetDomainId()
|
||||
cfg.State = Inactive
|
||||
cfg.ThingKey = mgThing.Credentials.Secret
|
||||
|
||||
@@ -182,11 +195,14 @@ func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (C
|
||||
}
|
||||
|
||||
func (bs bootstrapService) View(ctx context.Context, token, id string) (Config, error) {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
|
||||
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.ViewPermission, auth.ThingType, id); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
@@ -194,12 +210,15 @@ func (bs bootstrapService) View(ctx context.Context, token, id string) (Config,
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config) error {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, cfg.ThingID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Owner = owner
|
||||
cfg.DomainID = user.GetDomainId()
|
||||
if err = bs.configs.Update(ctx, cfg); err != nil {
|
||||
return errors.Wrap(errUpdateConnections, err)
|
||||
}
|
||||
@@ -207,11 +226,15 @@ func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config)
|
||||
}
|
||||
|
||||
func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clientCert, clientKey, caCert string) (Config, error) {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
cfg, err := bs.configs.UpdateCert(ctx, owner, thingID, clientCert, clientKey, caCert)
|
||||
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, thingID); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
cfg, err := bs.configs.UpdateCert(ctx, user.GetDomainId(), thingID, clientCert, clientKey, caCert)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errUpdateCert, err)
|
||||
}
|
||||
@@ -219,12 +242,16 @@ func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clien
|
||||
}
|
||||
|
||||
func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id string, connections []string) error {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
|
||||
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.EditPermission, auth.ThingType, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
|
||||
if err != nil {
|
||||
return errors.Wrap(errUpdateConnections, err)
|
||||
}
|
||||
@@ -232,7 +259,7 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id stri
|
||||
add, remove := bs.updateList(cfg, connections)
|
||||
|
||||
// Check if channels exist. This is the way to prevent fetching channels that already exist.
|
||||
existing, err := bs.configs.ListExisting(ctx, owner, connections)
|
||||
existing, err := bs.configs.ListExisting(ctx, user.GetDomainId(), connections)
|
||||
if err != nil {
|
||||
return errors.Wrap(errUpdateConnections, err)
|
||||
}
|
||||
@@ -268,26 +295,83 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id stri
|
||||
return ErrThings
|
||||
}
|
||||
}
|
||||
if err := bs.configs.UpdateConnections(ctx, owner, id, channels, connections); err != nil {
|
||||
if err := bs.configs.UpdateConnections(ctx, user.GetDomainId(), id, channels, connections); err != nil {
|
||||
return errors.Wrap(errUpdateConnections, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) listClientIDs(ctx context.Context, userID string) ([]string, error) {
|
||||
tids, err := bs.auth.ListAllObjects(ctx, &magistrala.ListObjectsReq{
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Permission: auth.ViewPermission,
|
||||
ObjectType: auth.ThingType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
return tids.Policies, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) checkSuperAdmin(ctx context.Context, userID string) error {
|
||||
res, err := bs.auth.Authorize(ctx, &magistrala.AuthorizeReq{
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Permission: auth.AdminPermission,
|
||||
ObjectType: auth.PlatformType,
|
||||
Object: auth.MagistralaObject,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !res.Authorized {
|
||||
return errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) List(ctx context.Context, token string, filter Filter, offset, limit uint64) (ConfigsPage, error) {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return ConfigsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
return bs.configs.RetrieveAll(ctx, owner, filter, offset, limit), nil
|
||||
|
||||
if err := bs.checkSuperAdmin(ctx, user.GetId()); err == nil {
|
||||
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), []string{}, filter, offset, limit), nil
|
||||
}
|
||||
|
||||
if _, err := bs.authorize(ctx, "", auth.UsersKind, user.GetId(), auth.AdminPermission, auth.DomainType, user.GetDomainId()); err == nil {
|
||||
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), []string{}, filter, offset, limit), nil
|
||||
}
|
||||
|
||||
// Handle non-admin users
|
||||
thingIDs, err := bs.listClientIDs(ctx, user.GetId())
|
||||
if err != nil {
|
||||
return ConfigsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
|
||||
if len(thingIDs) == 0 {
|
||||
return ConfigsPage{
|
||||
Total: 0,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Configs: []Config{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return bs.configs.RetrieveAll(ctx, user.GetDomainId(), thingIDs, filter, offset, limit), nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Remove(ctx context.Context, token, id string) error {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if err := bs.configs.Remove(ctx, owner, id); err != nil {
|
||||
if _, err := bs.authorize(ctx, user.GetDomainId(), auth.UsersKind, user.GetId(), auth.DeletePermission, auth.ThingType, id); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bs.configs.Remove(ctx, user.GetDomainId(), id); err != nil {
|
||||
return errors.Wrap(errRemoveBootstrap, err)
|
||||
}
|
||||
return nil
|
||||
@@ -313,12 +397,12 @@ func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalI
|
||||
}
|
||||
|
||||
func (bs bootstrapService) ChangeState(ctx context.Context, token, id string, state State) error {
|
||||
owner, err := bs.identify(ctx, token)
|
||||
user, err := bs.identify(ctx, token)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, user.GetDomainId(), id)
|
||||
if err != nil {
|
||||
return errors.Wrap(errChangeState, err)
|
||||
}
|
||||
@@ -352,7 +436,7 @@ func (bs bootstrapService) ChangeState(ctx context.Context, token, id string, st
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := bs.configs.ChangeState(ctx, owner, id, state); err != nil {
|
||||
if err := bs.configs.ChangeState(ctx, user.GetDomainId(), id, state); err != nil {
|
||||
return errors.Wrap(errChangeState, err)
|
||||
}
|
||||
return nil
|
||||
@@ -393,13 +477,36 @@ func (bs bootstrapService) DisconnectThingHandler(ctx context.Context, channelID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) identify(ctx context.Context, token string) (string, error) {
|
||||
func (bs bootstrapService) identify(ctx context.Context, token string) (*magistrala.IdentityRes, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := bs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if res.GetId() == "" || res.GetDomainId() == "" {
|
||||
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) authorize(ctx context.Context, domainID, subjKind, subj, perm, objType, obj string) (string, error) {
|
||||
req := &magistrala.AuthorizeReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.UserType,
|
||||
SubjectKind: subjKind,
|
||||
Subject: subj,
|
||||
Permission: perm,
|
||||
ObjectType: objType,
|
||||
Object: obj,
|
||||
}
|
||||
res, err := bs.auth.Authorize(ctx, req)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
if !res.GetAuthorized() {
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
return res.GetId(), nil
|
||||
@@ -451,6 +558,7 @@ func (bs bootstrapService) connectionChannels(channels, existing []string, token
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Metadata: ch.Metadata,
|
||||
DomainID: ch.DomainID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+834
-312
File diff suppressed because it is too large
Load Diff
@@ -27,7 +27,7 @@ func New(svc bootstrap.Service, tracer trace.Tracer) bootstrap.Service {
|
||||
func (tm *tracingMiddleware) Add(ctx context.Context, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_register_client", trace.WithAttributes(
|
||||
attribute.String("thing_id", cfg.ThingID),
|
||||
attribute.String("owner", cfg.Owner),
|
||||
attribute.String("domain_id ", cfg.DomainID),
|
||||
attribute.String("name", cfg.Name),
|
||||
attribute.String("external_id", cfg.ExternalID),
|
||||
attribute.String("content", cfg.Content),
|
||||
@@ -54,7 +54,7 @@ func (tm *tracingMiddleware) Update(ctx context.Context, token string, cfg boots
|
||||
attribute.String("name", cfg.Name),
|
||||
attribute.String("content", cfg.Content),
|
||||
attribute.String("thing_id", cfg.ThingID),
|
||||
attribute.String("owner", cfg.Owner),
|
||||
attribute.String("domain_id ", cfg.DomainID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user