NOISSUE - Update bootstrap content format, update profile method and add profile search (#3515)
Property Based Tests / api-test (push) Has been cancelled
Continuous Delivery / lint-and-build (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
CI Pipeline / Lint Proto (push) Has been cancelled
Continuous Delivery / Build and Push Docker Images (push) Has been cancelled
CI Pipeline / lint-and-build (push) Has been cancelled
CI Pipeline / Test ${{ matrix.module }} (push) Has been cancelled
CI Pipeline / Upload Coverage (push) Has been cancelled
CI Pipeline / Detect Changes (push) Has been cancelled

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2026-05-19 10:02:45 +03:00
committed by GitHub
parent f380c8d360
commit 683809dc6b
26 changed files with 952 additions and 286 deletions
+19 -4
View File
@@ -353,13 +353,27 @@ func renderPreviewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
} }
cfg := req.Config cfg := req.Config
bindings := req.Bindings
if req.ConfigID != "" {
stored, err := svc.View(ctx, session, req.ConfigID)
if err != nil {
return nil, err
}
cfg = stored
bindings, err = svc.ListBindings(ctx, session, req.ConfigID)
if err != nil {
return nil, err
}
}
cfg.DomainID = session.DomainID cfg.DomainID = session.DomainID
cfg.ProfileID = p.ID cfg.ProfileID = p.ID
if cfg.RenderContext == nil { if cfg.RenderContext == nil {
cfg.RenderContext = req.RenderContext cfg.RenderContext = req.RenderContext
} }
rendered, err := bootstrap.NewRenderer().Render(p, cfg, req.Bindings) rendered, err := bootstrap.NewRenderer().Render(p, cfg, bindings)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -379,10 +393,11 @@ func updateProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthorization return nil, svcerr.ErrAuthorization
} }
req.Profile.ID = req.profileID req.Profile.ID = req.profileID
if err := svc.UpdateProfile(ctx, session, req.Profile); err != nil { updated, err := svc.UpdateProfile(ctx, session, req.Profile)
if err != nil {
return nil, err return nil, err
} }
return profileRes{Profile: req.Profile}, nil return profileRes{Profile: updated}, nil
} }
} }
@@ -413,7 +428,7 @@ func listProfilesEndpoint(svc bootstrap.Service) endpoint.Endpoint {
if !ok { if !ok {
return nil, svcerr.ErrAuthorization return nil, svcerr.ErrAuthorization
} }
page, err := svc.ListProfiles(ctx, session, req.offset, req.limit) page, err := svc.ListProfiles(ctx, session, req.offset, req.limit, req.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+214 -44
View File
@@ -1182,7 +1182,7 @@ func TestUploadProfile(t *testing.T) {
saved := bootstrap.Profile{ saved := bootstrap.Profile{
ID: testsutil.GenerateUUID(t), ID: testsutil.GenerateUUID(t),
Name: "gateway", Name: "gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: "{{ .Device.ID }}", ContentTemplate: "{{ .Device.ID }}",
} }
@@ -1195,30 +1195,60 @@ func TestUploadProfile(t *testing.T) {
{ {
desc: "upload JSON profile", desc: "upload JSON profile",
contentType: "application/json", contentType: "application/json",
body: `{"name":"gateway","template_format":"go-template","content_template":"{{ .Device.ID }}"}`, body: `{"name":"gateway","content_format":"go-template","content_template":"{{ .Device.ID }}"}`,
profile: bootstrap.Profile{ profile: bootstrap.Profile{
Name: "gateway", Name: "gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: "{{ .Device.ID }}", ContentTemplate: "{{ .Device.ID }}",
}, },
}, },
{ {
desc: "upload YAML profile", desc: "upload YAML profile",
contentType: "application/yaml", contentType: "application/yaml",
body: "name: gateway\ntemplate_format: go-template\ncontent_template: '{{ .Device.ID }}'\n", body: "name: gateway\ncontent_format: go-template\ncontent_template: '{{ .Device.ID }}'\n",
profile: bootstrap.Profile{ profile: bootstrap.Profile{
Name: "gateway", Name: "gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: "{{ .Device.ID }}", ContentTemplate: "{{ .Device.ID }}",
}, },
}, },
{ {
desc: "upload TOML profile", desc: "upload TOML profile",
contentType: "application/toml", contentType: "application/toml",
body: "name = 'gateway'\ntemplate_format = 'go-template'\ncontent_template = '{{ .Device.ID }}'\n", body: "name = 'gateway'\ncontent_format = 'go-template'\ncontent_template = '{{ .Device.ID }}'\n",
profile: bootstrap.Profile{ profile: bootstrap.Profile{
Name: "gateway", Name: "gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: "{{ .Device.ID }}",
},
},
{
desc: "upload JSON profile without content_format infers json",
contentType: "application/json",
body: `{"name":"gateway","content_template":"{{ .Device.ID }}"}`,
profile: bootstrap.Profile{
Name: "gateway",
ContentFormat: bootstrap.ContentFormatJSON,
ContentTemplate: "{{ .Device.ID }}",
},
},
{
desc: "upload YAML profile without content_format infers yaml",
contentType: "application/yaml",
body: "name: gateway\ncontent_template: '{{ .Device.ID }}'\n",
profile: bootstrap.Profile{
Name: "gateway",
ContentFormat: bootstrap.ContentFormatYAML,
ContentTemplate: "{{ .Device.ID }}",
},
},
{
desc: "upload TOML profile without content_format infers toml",
contentType: "application/toml",
body: "name = 'gateway'\ncontent_template = '{{ .Device.ID }}'\n",
profile: bootstrap.Profile{
Name: "gateway",
ContentFormat: bootstrap.ContentFormatTOML,
ContentTemplate: "{{ .Device.ID }}", ContentTemplate: "{{ .Device.ID }}",
}, },
}, },
@@ -1246,6 +1276,83 @@ func TestUploadProfile(t *testing.T) {
} }
} }
func TestListProfiles(t *testing.T) {
bs, svc, auth := newBootstrapServer()
defer bs.Close()
session := smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
path := fmt.Sprintf("%s/%s/clients/bootstrap/profiles", bs.URL, domainID)
profiles := []bootstrap.Profile{
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "gateway-profile"},
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "sensor-profile"},
}
fullPage := bootstrap.ProfilesPage{Total: 2, Offset: 0, Limit: 10, Profiles: profiles}
filteredPage := bootstrap.ProfilesPage{Total: 1, Offset: 0, Limit: 10, Profiles: profiles[:1]}
cases := []struct {
desc string
token string
session smqauthn.Session
url string
name string
svcPage bootstrap.ProfilesPage
svcErr error
authenticateErr error
status int
}{
{
desc: "list profiles successfully",
token: validToken,
session: session,
url: fmt.Sprintf("%s?offset=0&limit=10", path),
svcPage: fullPage,
status: http.StatusOK,
},
{
desc: "list profiles filtered by name",
token: validToken,
session: session,
url: fmt.Sprintf("%s?offset=0&limit=10&name=gateway-profile", path),
name: "gateway-profile",
svcPage: filteredPage,
status: http.StatusOK,
},
{
desc: "list profiles with invalid token",
token: invalidToken,
url: fmt.Sprintf("%s?offset=0&limit=10", path),
authenticateErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
},
{
desc: "list profiles with limit exceeding max",
token: validToken,
session: session,
url: fmt.Sprintf("%s?offset=0&limit=101", path),
status: http.StatusBadRequest,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ListProfiles", mock.Anything, tc.session, mock.Anything, mock.Anything, tc.name).Return(tc.svcPage, tc.svcErr)
req := testRequest{
client: bs.Client(),
method: http.MethodGet,
url: tc.url,
token: tc.token,
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status %d got %d", tc.desc, tc.status, res.StatusCode))
authCall.Unset()
svcCall.Unset()
})
}
}
func TestProfileSlots(t *testing.T) { func TestProfileSlots(t *testing.T) {
bs, svc, auth := newBootstrapServer() bs, svc, auth := newBootstrapServer()
defer bs.Close() defer bs.Close()
@@ -1295,56 +1402,119 @@ func TestRenderPreview(t *testing.T) {
profile := bootstrap.Profile{ profile := bootstrap.Profile{
ID: profileID, ID: profileID,
Name: "gateway", Name: "gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: `device={{ .Device.ID }} site={{ .Vars.site }} topic={{ index (index .Bindings "telemetry").Snapshot "topic" }}`, ContentTemplate: `device={{ .Device.ID }} site={{ .Vars.site }} topic={{ index (index .Bindings "telemetry").Snapshot "topic" }}`,
} }
authCall := auth.On("Authenticate", mock.Anything, validToken).Return(session, nil)
svcCall := svc.On("ViewProfile", mock.Anything, session, profileID).Return(profile, nil)
reqBody := struct { storedConfig := bootstrap.Config{
ID: configID,
ExternalID: "gw-001",
DomainID: domainID,
RenderContext: map[string]any{
"site": "warehouse-1",
},
}
storedBindings := []bootstrap.BindingSnapshot{
{
Slot: "telemetry",
Type: "channel",
ResourceID: "ch-1",
Snapshot: map[string]any{"topic": "devices/gw-001/telemetry"},
},
}
inlineReqBody := struct {
Config bootstrap.Config `json:"config"` Config bootstrap.Config `json:"config"`
Bindings []bootstrap.BindingSnapshot `json:"bindings"` Bindings []bootstrap.BindingSnapshot `json:"bindings"`
}{ }{
Config: bootstrap.Config{ Config: bootstrap.Config{
ID: configID, ID: configID,
ExternalID: "gw-001", ExternalID: "gw-001",
RenderContext: map[string]any{ RenderContext: map[string]any{"site": "warehouse-1"},
"site": "warehouse-1",
},
}, },
Bindings: []bootstrap.BindingSnapshot{ Bindings: storedBindings,
{ }
Slot: "telemetry",
Type: "channel", configIDReqBody := struct {
ResourceID: "ch-1", ConfigID string `json:"config_id"`
Snapshot: map[string]any{ }{
"topic": "devices/gw-001/telemetry", ConfigID: configID,
}, }
},
expectedContent := "device=" + configID + " site=warehouse-1 topic=devices/gw-001/telemetry"
cases := []struct {
desc string
body string
profileErr error
configErr error
bindingsErr error
status int
}{
{
desc: "render preview with inline config and bindings",
body: toJSON(inlineReqBody),
status: http.StatusOK,
},
{
desc: "render preview with config_id loads from db",
body: toJSON(configIDReqBody),
status: http.StatusOK,
},
{
desc: "render preview with config_id and config not found",
body: toJSON(configIDReqBody),
configErr: svcerr.ErrNotFound,
status: http.StatusNotFound,
},
{
desc: "render preview with config_id and bindings error",
body: toJSON(configIDReqBody),
bindingsErr: svcerr.ErrViewEntity,
status: http.StatusUnprocessableEntity,
},
{
desc: "render preview with profile not found",
body: toJSON(inlineReqBody),
profileErr: svcerr.ErrNotFound,
status: http.StatusNotFound,
}, },
} }
req := testRequest{ for _, tc := range cases {
client: bs.Client(), t.Run(tc.desc, func(t *testing.T) {
method: http.MethodPost, authCall := auth.On("Authenticate", mock.Anything, validToken).Return(session, nil)
url: fmt.Sprintf("%s/%s/clients/bootstrap/profiles/%s/render-preview", bs.URL, domainID, profileID), svcCall := svc.On("ViewProfile", mock.Anything, session, profileID).Return(profile, tc.profileErr)
contentType: contentType, svcCall2 := svc.On("View", mock.Anything, session, configID).Return(storedConfig, tc.configErr)
token: validToken, svcCall3 := svc.On("ListBindings", mock.Anything, session, configID).Return(storedBindings, tc.bindingsErr)
body: strings.NewReader(toJSON(reqBody)),
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("render preview unexpected error %s", err))
assert.Equal(t, http.StatusOK, res.StatusCode, fmt.Sprintf("expected status code %d got %d", http.StatusOK, res.StatusCode))
var got struct { req := testRequest{
Content string `json:"content"` client: bs.Client(),
} method: http.MethodPost,
err = json.NewDecoder(res.Body).Decode(&got) url: fmt.Sprintf("%s/%s/clients/bootstrap/profiles/%s/render-preview", bs.URL, domainID, profileID),
assert.Nil(t, err, fmt.Sprintf("decoding render preview expected to succeed: %s", err)) contentType: contentType,
assert.Equal(t, "device="+configID+" site=warehouse-1 topic=devices/gw-001/telemetry", got.Content) token: validToken,
body: strings.NewReader(tc.body),
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset() if tc.status == http.StatusOK {
authCall.Unset() var got struct {
Content string `json:"content"`
}
err = json.NewDecoder(res.Body).Decode(&got)
assert.Nil(t, err, fmt.Sprintf("%s: decoding expected to succeed: %s", tc.desc, err))
assert.Equal(t, expectedContent, got.Content, fmt.Sprintf("%s: expected content %q got %q", tc.desc, expectedContent, got.Content))
}
svcCall3.Unset()
svcCall2.Unset()
svcCall.Unset()
authCall.Unset()
})
}
} }
type config struct { type config struct {
+2
View File
@@ -178,6 +178,7 @@ func (req updateProfileReq) validate() error {
type renderPreviewReq struct { type renderPreviewReq struct {
profileID string profileID string
ConfigID string `json:"config_id,omitempty"`
Config bootstrap.Config `json:"config"` Config bootstrap.Config `json:"config"`
RenderContext map[string]any `json:"render_context,omitempty"` RenderContext map[string]any `json:"render_context,omitempty"`
Bindings []bootstrap.BindingSnapshot `json:"bindings,omitempty"` Bindings []bootstrap.BindingSnapshot `json:"bindings,omitempty"`
@@ -204,6 +205,7 @@ func (req deleteProfileReq) validate() error {
type listProfilesReq struct { type listProfilesReq struct {
offset uint64 offset uint64
limit uint64 limit uint64
name string
} }
func (req listProfilesReq) validate() error { func (req listProfilesReq) validate() error {
+13 -1
View File
@@ -368,13 +368,16 @@ func decodeCreateProfileRequest(_ context.Context, r *http.Request) (any, error)
func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) { func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) {
contentType := r.Header.Get("Content-Type") contentType := r.Header.Get("Content-Type")
var req uploadProfileReq var req uploadProfileReq
var inferredFormat bootstrap.ContentFormat
switch { switch {
case strings.Contains(contentType, "json"): case strings.Contains(contentType, "json"):
inferredFormat = bootstrap.ContentFormatJSON
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil { if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
} }
case strings.Contains(contentType, yamlContentType): case strings.Contains(contentType, yamlContentType):
inferredFormat = bootstrap.ContentFormatYAML
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
@@ -383,6 +386,7 @@ func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
} }
case strings.Contains(contentType, tomlContentType): case strings.Contains(contentType, tomlContentType):
inferredFormat = bootstrap.ContentFormatTOML
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
@@ -394,6 +398,10 @@ func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error)
return nil, apiutil.ErrUnsupportedContentType return nil, apiutil.ErrUnsupportedContentType
} }
if req.Profile.ContentFormat == "" {
req.Profile.ContentFormat = inferredFormat
}
return req, nil return req, nil
} }
@@ -430,7 +438,11 @@ func decodeListProfilesRequest(_ context.Context, r *http.Request) (any, error)
if err != nil { if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err) return nil, errors.Wrap(apiutil.ErrValidation, err)
} }
return listProfilesReq{offset: o, limit: l}, nil n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
return listProfilesReq{offset: o, limit: l, name: n}, nil
} }
func decodeProfileEntityRequest(_ context.Context, r *http.Request) (any, error) { func decodeProfileEntityRequest(_ context.Context, r *http.Request) (any, error) {
+1 -1
View File
@@ -98,7 +98,7 @@ func mergeBindingSnapshots(existing, updated []BindingSnapshot) []BindingSnapsho
} }
func validateProfileTemplate(p Profile) error { func validateProfileTemplate(p Profile) error {
if p.ContentTemplate == "" || p.TemplateFormat == TemplateFormatRaw { if p.ContentTemplate == "" || p.ContentFormat == ContentFormatRaw {
return nil return nil
} }
_, err := template.New("bootstrap").Funcs(allowlistedFuncs()).Parse(p.ContentTemplate) _, err := template.New("bootstrap").Funcs(allowlistedFuncs()).Parse(p.ContentTemplate)
+8 -7
View File
@@ -214,16 +214,17 @@ func (es *eventStore) ViewProfile(ctx context.Context, session smqauthn.Session,
return p, nil return p, nil
} }
func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
if err := es.svc.UpdateProfile(ctx, session, p); err != nil { updated, err := es.svc.UpdateProfile(ctx, session, p)
return err if err != nil {
return bootstrap.Profile{}, err
} }
ev := profileEvent{p, profileUpdate} ev := profileEvent{updated, profileUpdate}
return es.Publish(ctx, updateProfileStream, ev) return updated, es.Publish(ctx, updateProfileStream, ev)
} }
func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
pp, err := es.svc.ListProfiles(ctx, session, offset, limit) pp, err := es.svc.ListProfiles(ctx, session, offset, limit, name)
if err != nil { if err != nil {
return pp, err return pp, err
} }
+4 -4
View File
@@ -124,18 +124,18 @@ func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session smqa
return am.svc.ViewProfile(ctx, session, profileID) return am.svc.ViewProfile(ctx, session, profileID)
} }
func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, auth.AnyIDs); err != nil { if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, auth.AnyIDs); err != nil {
return err return bootstrap.Profile{}, err
} }
return am.svc.UpdateProfile(ctx, session, p) return am.svc.UpdateProfile(ctx, session, p)
} }
func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err != nil { if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err != nil {
return bootstrap.ProfilesPage{}, err return bootstrap.ProfilesPage{}, err
} }
return am.svc.ListProfiles(ctx, session, offset, limit) return am.svc.ListProfiles(ctx, session, offset, limit, name)
} }
func (am *authorizationMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { func (am *authorizationMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
+3 -3
View File
@@ -233,7 +233,7 @@ func (lm *loggingMiddleware) ViewProfile(ctx context.Context, session smqauthn.S
return lm.svc.ViewProfile(ctx, session, profileID) return lm.svc.ViewProfile(ctx, session, profileID)
} }
func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (err error) { func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (updated bootstrap.Profile, err error) {
defer func(begin time.Time) { defer func(begin time.Time) {
args := []any{ args := []any{
slog.String("duration", time.Since(begin).String()), slog.String("duration", time.Since(begin).String()),
@@ -250,7 +250,7 @@ func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn
return lm.svc.UpdateProfile(ctx, session, p) return lm.svc.UpdateProfile(ctx, session, p)
} }
func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (page bootstrap.ProfilesPage, err error) { func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (page bootstrap.ProfilesPage, err error) {
defer func(begin time.Time) { defer func(begin time.Time) {
args := []any{ args := []any{
slog.String("duration", time.Since(begin).String()), slog.String("duration", time.Since(begin).String()),
@@ -265,7 +265,7 @@ func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.
lm.logger.Info("List profiles completed successfully", args...) lm.logger.Info("List profiles completed successfully", args...)
}(time.Now()) }(time.Now())
return lm.svc.ListProfiles(ctx, session, offset, limit) return lm.svc.ListProfiles(ctx, session, offset, limit, name)
} }
func (lm *loggingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) (err error) { func (lm *loggingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) (err error) {
+3 -3
View File
@@ -135,7 +135,7 @@ func (mm *metricsMiddleware) ViewProfile(ctx context.Context, session smqauthn.S
return mm.svc.ViewProfile(ctx, session, profileID) return mm.svc.ViewProfile(ctx, session, profileID)
} }
func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
defer func(begin time.Time) { defer func(begin time.Time) {
mm.counter.With("method", "update_profile").Add(1) mm.counter.With("method", "update_profile").Add(1)
mm.latency.With("method", "update_profile").Observe(time.Since(begin).Seconds()) mm.latency.With("method", "update_profile").Observe(time.Since(begin).Seconds())
@@ -143,12 +143,12 @@ func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn
return mm.svc.UpdateProfile(ctx, session, p) return mm.svc.UpdateProfile(ctx, session, p)
} }
func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
defer func(begin time.Time) { defer func(begin time.Time) {
mm.counter.With("method", "list_profiles").Add(1) mm.counter.With("method", "list_profiles").Add(1)
mm.latency.With("method", "list_profiles").Observe(time.Since(begin).Seconds()) mm.latency.With("method", "list_profiles").Observe(time.Since(begin).Seconds())
}(time.Now()) }(time.Now())
return mm.svc.ListProfiles(ctx, session, offset, limit) return mm.svc.ListProfiles(ctx, session, offset, limit, name)
} }
func (mm *metricsMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { func (mm *metricsMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
+35 -20
View File
@@ -106,8 +106,8 @@ func (_c *ProfileRepository_Delete_Call) RunAndReturn(run func(ctx context.Conte
} }
// RetrieveAll provides a mock function for the type ProfileRepository // RetrieveAll provides a mock function for the type ProfileRepository
func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string, offset uint64, limit uint64) (bootstrap.ProfilesPage, error) { func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error) {
ret := _mock.Called(ctx, domainID, offset, limit) ret := _mock.Called(ctx, domainID, offset, limit, name)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for RetrieveAll") panic("no return value specified for RetrieveAll")
@@ -115,16 +115,16 @@ func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string
var r0 bootstrap.ProfilesPage var r0 bootstrap.ProfilesPage
var r1 error var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (bootstrap.ProfilesPage, error)); ok { if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64, string) (bootstrap.ProfilesPage, error)); ok {
return returnFunc(ctx, domainID, offset, limit) return returnFunc(ctx, domainID, offset, limit, name)
} }
if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) bootstrap.ProfilesPage); ok { if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64, string) bootstrap.ProfilesPage); ok {
r0 = returnFunc(ctx, domainID, offset, limit) r0 = returnFunc(ctx, domainID, offset, limit, name)
} else { } else {
r0 = ret.Get(0).(bootstrap.ProfilesPage) r0 = ret.Get(0).(bootstrap.ProfilesPage)
} }
if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64, string) error); ok {
r1 = returnFunc(ctx, domainID, offset, limit) r1 = returnFunc(ctx, domainID, offset, limit, name)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@@ -141,11 +141,12 @@ type ProfileRepository_RetrieveAll_Call struct {
// - domainID string // - domainID string
// - offset uint64 // - offset uint64
// - limit uint64 // - limit uint64
func (_e *ProfileRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, offset interface{}, limit interface{}) *ProfileRepository_RetrieveAll_Call { // - name string
return &ProfileRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, offset, limit)} func (_e *ProfileRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, offset interface{}, limit interface{}, name interface{}) *ProfileRepository_RetrieveAll_Call {
return &ProfileRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, offset, limit, name)}
} }
func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, offset uint64, limit uint64)) *ProfileRepository_RetrieveAll_Call { func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, offset uint64, limit uint64, name string)) *ProfileRepository_RetrieveAll_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context var arg0 context.Context
if args[0] != nil { if args[0] != nil {
@@ -163,11 +164,16 @@ func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context,
if args[3] != nil { if args[3] != nil {
arg3 = args[3].(uint64) arg3 = args[3].(uint64)
} }
var arg4 string
if args[4] != nil {
arg4 = args[4].(string)
}
run( run(
arg0, arg0,
arg1, arg1,
arg2, arg2,
arg3, arg3,
arg4,
) )
}) })
return _c return _c
@@ -178,7 +184,7 @@ func (_c *ProfileRepository_RetrieveAll_Call) Return(profilesPage bootstrap.Prof
return _c return _c
} }
func (_c *ProfileRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, offset uint64, limit uint64) (bootstrap.ProfilesPage, error)) *ProfileRepository_RetrieveAll_Call { func (_c *ProfileRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error)) *ProfileRepository_RetrieveAll_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@@ -322,20 +328,29 @@ func (_c *ProfileRepository_Save_Call) RunAndReturn(run func(ctx context.Context
} }
// Update provides a mock function for the type ProfileRepository // Update provides a mock function for the type ProfileRepository
func (_mock *ProfileRepository) Update(ctx context.Context, p bootstrap.Profile) error { func (_mock *ProfileRepository) Update(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) {
ret := _mock.Called(ctx, p) ret := _mock.Called(ctx, p)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Update") panic("no return value specified for Update")
} }
var r0 error var r0 bootstrap.Profile
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) error); ok { var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) (bootstrap.Profile, error)); ok {
return returnFunc(ctx, p)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) bootstrap.Profile); ok {
r0 = returnFunc(ctx, p) r0 = returnFunc(ctx, p)
} else { } else {
r0 = ret.Error(0) r0 = ret.Get(0).(bootstrap.Profile)
} }
return r0 if returnFunc, ok := ret.Get(1).(func(context.Context, bootstrap.Profile) error); ok {
r1 = returnFunc(ctx, p)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// ProfileRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' // ProfileRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update'
@@ -368,12 +383,12 @@ func (_c *ProfileRepository_Update_Call) Run(run func(ctx context.Context, p boo
return _c return _c
} }
func (_c *ProfileRepository_Update_Call) Return(err error) *ProfileRepository_Update_Call { func (_c *ProfileRepository_Update_Call) Return(profile bootstrap.Profile, err error) *ProfileRepository_Update_Call {
_c.Call.Return(err) _c.Call.Return(profile, err)
return _c return _c
} }
func (_c *ProfileRepository_Update_Call) RunAndReturn(run func(ctx context.Context, p bootstrap.Profile) error) *ProfileRepository_Update_Call { func (_c *ProfileRepository_Update_Call) RunAndReturn(run func(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error)) *ProfileRepository_Update_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
+35 -20
View File
@@ -781,8 +781,8 @@ func (_c *Service_ListBindings_Call) RunAndReturn(run func(ctx context.Context,
} }
// ListProfiles provides a mock function for the type Service // ListProfiles provides a mock function for the type Service
func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, offset uint64, limit uint64) (bootstrap.ProfilesPage, error) { func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error) {
ret := _mock.Called(ctx, session, offset, limit) ret := _mock.Called(ctx, session, offset, limit, name)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for ListProfiles") panic("no return value specified for ListProfiles")
@@ -790,16 +790,16 @@ func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, o
var r0 bootstrap.ProfilesPage var r0 bootstrap.ProfilesPage
var r1 error var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64) (bootstrap.ProfilesPage, error)); ok { if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64, string) (bootstrap.ProfilesPage, error)); ok {
return returnFunc(ctx, session, offset, limit) return returnFunc(ctx, session, offset, limit, name)
} }
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64) bootstrap.ProfilesPage); ok { if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64, string) bootstrap.ProfilesPage); ok {
r0 = returnFunc(ctx, session, offset, limit) r0 = returnFunc(ctx, session, offset, limit, name)
} else { } else {
r0 = ret.Get(0).(bootstrap.ProfilesPage) r0 = ret.Get(0).(bootstrap.ProfilesPage)
} }
if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, uint64, uint64) error); ok { if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, uint64, uint64, string) error); ok {
r1 = returnFunc(ctx, session, offset, limit) r1 = returnFunc(ctx, session, offset, limit, name)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@@ -816,11 +816,12 @@ type Service_ListProfiles_Call struct {
// - session authn.Session // - session authn.Session
// - offset uint64 // - offset uint64
// - limit uint64 // - limit uint64
func (_e *Service_Expecter) ListProfiles(ctx interface{}, session interface{}, offset interface{}, limit interface{}) *Service_ListProfiles_Call { // - name string
return &Service_ListProfiles_Call{Call: _e.mock.On("ListProfiles", ctx, session, offset, limit)} func (_e *Service_Expecter) ListProfiles(ctx interface{}, session interface{}, offset interface{}, limit interface{}, name interface{}) *Service_ListProfiles_Call {
return &Service_ListProfiles_Call{Call: _e.mock.On("ListProfiles", ctx, session, offset, limit, name)}
} }
func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64)) *Service_ListProfiles_Call { func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string)) *Service_ListProfiles_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context var arg0 context.Context
if args[0] != nil { if args[0] != nil {
@@ -838,11 +839,16 @@ func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session a
if args[3] != nil { if args[3] != nil {
arg3 = args[3].(uint64) arg3 = args[3].(uint64)
} }
var arg4 string
if args[4] != nil {
arg4 = args[4].(string)
}
run( run(
arg0, arg0,
arg1, arg1,
arg2, arg2,
arg3, arg3,
arg4,
) )
}) })
return _c return _c
@@ -853,7 +859,7 @@ func (_c *Service_ListProfiles_Call) Return(profilesPage bootstrap.ProfilesPage,
return _c return _c
} }
func (_c *Service_ListProfiles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64) (bootstrap.ProfilesPage, error)) *Service_ListProfiles_Call { func (_c *Service_ListProfiles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error)) *Service_ListProfiles_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@@ -1144,20 +1150,29 @@ func (_c *Service_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, se
} }
// UpdateProfile provides a mock function for the type Service // UpdateProfile provides a mock function for the type Service
func (_mock *Service) UpdateProfile(ctx context.Context, session authn.Session, p bootstrap.Profile) error { func (_mock *Service) UpdateProfile(ctx context.Context, session authn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
ret := _mock.Called(ctx, session, p) ret := _mock.Called(ctx, session, p)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for UpdateProfile") panic("no return value specified for UpdateProfile")
} }
var r0 error var r0 bootstrap.Profile
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) error); ok { var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) (bootstrap.Profile, error)); ok {
return returnFunc(ctx, session, p)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) bootstrap.Profile); ok {
r0 = returnFunc(ctx, session, p) r0 = returnFunc(ctx, session, p)
} else { } else {
r0 = ret.Error(0) r0 = ret.Get(0).(bootstrap.Profile)
} }
return r0 if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, bootstrap.Profile) error); ok {
r1 = returnFunc(ctx, session, p)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// Service_UpdateProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateProfile' // Service_UpdateProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateProfile'
@@ -1196,12 +1211,12 @@ func (_c *Service_UpdateProfile_Call) Run(run func(ctx context.Context, session
return _c return _c
} }
func (_c *Service_UpdateProfile_Call) Return(err error) *Service_UpdateProfile_Call { func (_c *Service_UpdateProfile_Call) Return(profile bootstrap.Profile, err error) *Service_UpdateProfile_Call {
_c.Call.Return(err) _c.Call.Return(profile, err)
return _c return _c
} }
func (_c *Service_UpdateProfile_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, p bootstrap.Profile) error) *Service_UpdateProfile_Call { func (_c *Service_UpdateProfile_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, p bootstrap.Profile) (bootstrap.Profile, error)) *Service_UpdateProfile_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
+5 -5
View File
@@ -454,11 +454,11 @@ func TestAssignProfile(t *testing.T) {
profileID := testsutil.GenerateUUID(t) profileID := testsutil.GenerateUUID(t)
_, err = profileRepo.Save(context.Background(), bootstrap.Profile{ _, err = profileRepo.Save(context.Background(), bootstrap.Profile{
ID: profileID, ID: profileID,
DomainID: c.DomainID, DomainID: c.DomainID,
Name: "edge-gateway", Name: "edge-gateway",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
Version: 1, Version: 1,
}) })
require.Nil(t, err, fmt.Sprintf("Saving profile expected to succeed: %s.\n", err)) require.Nil(t, err, fmt.Sprintf("Saving profile expected to succeed: %s.\n", err))
+9
View File
@@ -315,6 +315,15 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE IF EXISTS profiles DROP COLUMN IF EXISTS binding_slots`, `ALTER TABLE IF EXISTS profiles DROP COLUMN IF EXISTS binding_slots`,
}, },
}, },
{
Id: "configs_16",
Up: []string{
`ALTER TABLE IF EXISTS profiles RENAME COLUMN template_format TO content_format`,
},
Down: []string{
`ALTER TABLE IF EXISTS profiles RENAME COLUMN content_format TO template_format`,
},
},
}, },
} }
} }
+98 -44
View File
@@ -9,14 +9,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"strings"
"time" "time"
"github.com/absmach/magistrala/bootstrap" "github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/pkg/errors" "github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository" repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/pkg/postgres" "github.com/absmach/magistrala/pkg/postgres"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
) )
var _ bootstrap.ProfileRepository = (*profileRepository)(nil) var _ bootstrap.ProfileRepository = (*profileRepository)(nil)
@@ -32,8 +31,8 @@ func NewProfileRepository(db postgres.Database, log *slog.Logger) bootstrap.Prof
} }
func (pr profileRepository) Save(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) { func (pr profileRepository) Save(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) {
q := `INSERT INTO profiles (id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at) q := `INSERT INTO profiles (id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at)
VALUES (:id, :domain_id, :name, :description, :template_format, :content_template, :defaults, :binding_slots, :version, :created_at, :updated_at)` VALUES (:id, :domain_id, :name, :description, :content_format, :content_template, :defaults, :binding_slots, :version, :created_at, :updated_at)`
now := time.Now().UTC() now := time.Now().UTC()
p.CreatedAt = now p.CreatedAt = now
@@ -45,35 +44,42 @@ func (pr profileRepository) Save(ctx context.Context, p bootstrap.Profile) (boot
} }
if _, err = pr.db.NamedExecContext(ctx, q, dbp); err != nil { if _, err = pr.db.NamedExecContext(ctx, q, dbp); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == pgerrcode.UniqueViolation { return bootstrap.Profile{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
return bootstrap.Profile{}, repoerr.ErrConflict
}
return bootstrap.Profile{}, errors.Wrap(repoerr.ErrCreateEntity, err)
} }
return p, nil return p, nil
} }
func (pr profileRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Profile, error) { func (pr profileRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Profile, error) {
q := `SELECT id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at q := `SELECT id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at
FROM profiles WHERE id = $1 AND domain_id = $2` FROM profiles WHERE id = :id AND domain_id = :domain_id`
rows, err := pr.db.NamedQueryContext(ctx, q, dbProfile{ID: id, DomainID: domainID})
if err != nil {
return bootstrap.Profile{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
if !rows.Next() {
return bootstrap.Profile{}, repoerr.ErrNotFound
}
var dbp dbProfile var dbp dbProfile
if err := pr.db.QueryRowxContext(ctx, q, id, domainID).StructScan(&dbp); err != nil { if err := rows.StructScan(&dbp); err != nil {
if err == sql.ErrNoRows {
return bootstrap.Profile{}, repoerr.ErrNotFound
}
return bootstrap.Profile{}, errors.Wrap(repoerr.ErrViewEntity, err) return bootstrap.Profile{}, errors.Wrap(repoerr.ErrViewEntity, err)
} }
return toProfile(dbp) return toProfile(dbp)
} }
func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, offset, limit uint64) (bootstrap.ProfilesPage, error) { func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
q := `SELECT id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at dbPage := dbProfilesPage{DomainID: domainID, Offset: offset, Limit: limit, Name: name}
FROM profiles WHERE domain_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3` pageQuery := profilesPageQuery(dbPage)
q := fmt.Sprintf(`SELECT id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at
FROM profiles %s`, pageQuery)
q = applyProfilesOrdering(q)
q = fmt.Sprintf(`%s LIMIT :limit OFFSET :offset`, q)
rows, err := pr.db.QueryxContext(ctx, q, domainID, limit, offset) rows, err := pr.db.NamedQueryContext(ctx, q, dbPage)
if err != nil { if err != nil {
return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
} }
@@ -93,8 +99,9 @@ func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, of
profiles = append(profiles, p) profiles = append(profiles, p)
} }
var total uint64 cq := fmt.Sprintf(`SELECT COUNT(*) FROM profiles %s`, pageQuery)
if err := pr.db.QueryRowxContext(ctx, `SELECT COUNT(*) FROM profiles WHERE domain_id = $1`, domainID).Scan(&total); err != nil { total, err := postgres.Total(ctx, pr.db, cq, dbPage)
if err != nil {
return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
} }
@@ -106,35 +113,82 @@ func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, of
}, nil }, nil
} }
func (pr profileRepository) Update(ctx context.Context, p bootstrap.Profile) error { type dbProfilesPage struct {
q := `UPDATE profiles SET name = :name, description = :description, template_format = :template_format, DomainID string `db:"domain_id"`
content_template = :content_template, defaults = :defaults, binding_slots = :binding_slots, version = version + 1, updated_at = :updated_at Offset uint64 `db:"offset"`
WHERE id = :id AND domain_id = :domain_id` Limit uint64 `db:"limit"`
Name string `db:"name"`
}
func profilesPageQuery(pm dbProfilesPage) string {
var query []string
query = append(query, "domain_id = :domain_id")
if pm.Name != "" {
query = append(query, "name ILIKE '%' || :name || '%'")
}
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
}
func applyProfilesOrdering(q string) string {
return fmt.Sprintf("%s ORDER BY created_at DESC", q)
}
func (pr profileRepository) Update(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) {
var query []string
var upq string
if p.Name != "" {
query = append(query, "name = :name,")
}
if p.Description != "" {
query = append(query, "description = :description,")
}
if p.ContentFormat != "" {
query = append(query, "content_format = :content_format,")
}
if p.ContentTemplate != "" {
query = append(query, "content_template = :content_template,")
}
if p.Defaults != nil {
query = append(query, "defaults = :defaults,")
}
if p.BindingSlots != nil {
query = append(query, "binding_slots = :binding_slots,")
}
if len(query) > 0 {
upq = strings.Join(query, " ")
}
q := fmt.Sprintf(`UPDATE profiles SET %s version = version + 1, updated_at = :updated_at
WHERE id = :id AND domain_id = :domain_id
RETURNING id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at`,
upq)
p.UpdatedAt = time.Now().UTC() p.UpdatedAt = time.Now().UTC()
dbp, err := toDBProfile(p) dbp, err := toDBProfile(p)
if err != nil { if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err) return bootstrap.Profile{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
} }
res, err := pr.db.NamedExecContext(ctx, q, dbp) rows, err := pr.db.NamedQueryContext(ctx, q, dbp)
if err != nil { if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err) return bootstrap.Profile{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
} }
cnt, err := res.RowsAffected() defer rows.Close()
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err) if !rows.Next() {
return bootstrap.Profile{}, repoerr.ErrNotFound
} }
if cnt == 0 { var updated dbProfile
return repoerr.ErrNotFound if err := rows.StructScan(&updated); err != nil {
return bootstrap.Profile{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
} }
return nil return toProfile(updated)
} }
func (pr profileRepository) Delete(ctx context.Context, domainID, id string) error { func (pr profileRepository) Delete(ctx context.Context, domainID, id string) error {
q := `DELETE FROM profiles WHERE id = $1 AND domain_id = $2` q := `DELETE FROM profiles WHERE id = :id AND domain_id = :domain_id`
if _, err := pr.db.ExecContext(ctx, q, id, domainID); err != nil { if _, err := pr.db.NamedExecContext(ctx, q, dbProfile{ID: id, DomainID: domainID}); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err) return errors.Wrap(repoerr.ErrRemoveEntity, err)
} }
return nil return nil
@@ -146,7 +200,7 @@ type dbProfile struct {
DomainID string `db:"domain_id"` DomainID string `db:"domain_id"`
Name string `db:"name"` Name string `db:"name"`
Description sql.NullString `db:"description"` Description sql.NullString `db:"description"`
TemplateFormat string `db:"template_format"` ContentFormat string `db:"content_format"`
ContentTemplate sql.NullString `db:"content_template"` ContentTemplate sql.NullString `db:"content_template"`
Defaults []byte `db:"defaults"` Defaults []byte `db:"defaults"`
BindingSlots []byte `db:"binding_slots"` BindingSlots []byte `db:"binding_slots"`
@@ -169,7 +223,7 @@ func toDBProfile(p bootstrap.Profile) (dbProfile, error) {
DomainID: p.DomainID, DomainID: p.DomainID,
Name: p.Name, Name: p.Name,
Description: nullString(p.Description), Description: nullString(p.Description),
TemplateFormat: string(p.TemplateFormat), ContentFormat: string(p.ContentFormat),
ContentTemplate: nullString(p.ContentTemplate), ContentTemplate: nullString(p.ContentTemplate),
Defaults: defaults, Defaults: defaults,
BindingSlots: bindingSlots, BindingSlots: bindingSlots,
@@ -181,13 +235,13 @@ func toDBProfile(p bootstrap.Profile) (dbProfile, error) {
func toProfile(dbp dbProfile) (bootstrap.Profile, error) { func toProfile(dbp dbProfile) (bootstrap.Profile, error) {
p := bootstrap.Profile{ p := bootstrap.Profile{
ID: dbp.ID, ID: dbp.ID,
DomainID: dbp.DomainID, DomainID: dbp.DomainID,
Name: dbp.Name, Name: dbp.Name,
TemplateFormat: bootstrap.TemplateFormat(dbp.TemplateFormat), ContentFormat: bootstrap.ContentFormat(dbp.ContentFormat),
Version: dbp.Version, Version: dbp.Version,
CreatedAt: dbp.CreatedAt, CreatedAt: dbp.CreatedAt,
UpdatedAt: dbp.UpdatedAt, UpdatedAt: dbp.UpdatedAt,
} }
if dbp.Description.Valid { if dbp.Description.Valid {
p.Description = dbp.Description.String p.Description = dbp.Description.String
+12 -12
View File
@@ -8,15 +8,15 @@ import (
"time" "time"
) )
// TemplateFormat enumerates supported content template formats. // ContentFormat enumerates the supported output formats for rendered profile templates.
type TemplateFormat string type ContentFormat string
const ( const (
TemplateFormatGoTemplate TemplateFormat = "go-template" ContentFormatGoTemplate ContentFormat = "go-template"
TemplateFormatRaw TemplateFormat = "raw" ContentFormatRaw ContentFormat = "raw"
TemplateFormatJSON TemplateFormat = "json" ContentFormatJSON ContentFormat = "json"
TemplateFormatYAML TemplateFormat = "yaml" ContentFormatYAML ContentFormat = "yaml"
TemplateFormatTOML TemplateFormat = "toml" ContentFormatTOML ContentFormat = "toml"
) )
// Profile is a user-managed device configuration template. // Profile is a user-managed device configuration template.
@@ -25,7 +25,7 @@ type Profile struct {
DomainID string `json:"domain_id,omitempty"` DomainID string `json:"domain_id,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
TemplateFormat TemplateFormat `json:"template_format"` ContentFormat ContentFormat `json:"content_format"`
ContentTemplate string `json:"content_template,omitempty"` ContentTemplate string `json:"content_template,omitempty"`
Defaults map[string]any `json:"defaults,omitempty"` Defaults map[string]any `json:"defaults,omitempty"`
BindingSlots []BindingSlot `json:"binding_slots,omitempty"` BindingSlots []BindingSlot `json:"binding_slots,omitempty"`
@@ -58,11 +58,11 @@ type ProfileRepository interface {
// RetrieveByID returns the Profile with the given ID inside the given domain. // RetrieveByID returns the Profile with the given ID inside the given domain.
RetrieveByID(ctx context.Context, domainID, id string) (Profile, error) RetrieveByID(ctx context.Context, domainID, id string) (Profile, error)
// RetrieveAll returns a page of Profiles belonging to the given domain. // RetrieveAll returns a page of Profiles belonging to the given domain, optionally filtered by name.
RetrieveAll(ctx context.Context, domainID string, offset, limit uint64) (ProfilesPage, error) RetrieveAll(ctx context.Context, domainID string, offset, limit uint64, name string) (ProfilesPage, error)
// Update updates editable fields of the given Profile. // Update updates editable fields of the given Profile and returns the updated Profile.
Update(ctx context.Context, p Profile) error Update(ctx context.Context, p Profile) (Profile, error)
// Delete removes the Profile with the given ID from the given domain. // Delete removes the Profile with the given ID from the given domain.
Delete(ctx context.Context, domainID, id string) error Delete(ctx context.Context, domainID, id string) error
+52 -29
View File
@@ -34,13 +34,13 @@ func NewRenderer() Renderer {
func (r renderer) Render(profile Profile, enrollment Config, bindings []BindingSnapshot) ([]byte, error) { func (r renderer) Render(profile Profile, enrollment Config, bindings []BindingSnapshot) ([]byte, error) {
rctx := buildRenderContext(profile, enrollment, bindings) rctx := buildRenderContext(profile, enrollment, bindings)
switch profile.TemplateFormat { switch profile.ContentFormat {
case TemplateFormatRaw: case ContentFormatRaw:
return []byte(profile.ContentTemplate), nil return []byte(profile.ContentTemplate), nil
case TemplateFormatGoTemplate, TemplateFormatJSON, TemplateFormatYAML, TemplateFormatTOML, "": case ContentFormatGoTemplate, ContentFormatJSON, ContentFormatYAML, ContentFormatTOML, "":
return r.renderTemplate(profile, rctx) return r.renderTemplate(profile, rctx)
default: default:
return nil, fmt.Errorf("%w: unsupported template format %q", ErrRenderFailed, profile.TemplateFormat) return nil, fmt.Errorf("%w: unsupported template format %q", ErrRenderFailed, profile.ContentFormat)
} }
} }
@@ -58,38 +58,61 @@ func (r renderer) renderTemplate(profile Profile, rctx RenderContext) ([]byte, e
return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err) return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err)
} }
out, err := validateRenderedOutput(buf.Bytes(), profile.TemplateFormat) return convertOutput(buf.Bytes(), profile.ContentFormat)
if err != nil {
return nil, err
}
return out, nil
} }
// validateRenderedOutput checks that the rendered bytes are valid for the // convertOutput parses the rendered bytes as any structured format (JSON, YAML,
// declared output format. It returns the original bytes on success and wraps // or TOML) and re-marshals them into the declared target format. For go-template
// ErrRenderFailed on failure. // or empty format the raw bytes are returned unchanged.
func validateRenderedOutput(out []byte, format TemplateFormat) ([]byte, error) { func convertOutput(out []byte, format ContentFormat) ([]byte, error) {
// Unrecognised formats are passed through. Recognised structured formats
// must parse successfully so broken templates fail before reaching devices.
switch format { switch format {
case TemplateFormatJSON: case ContentFormatGoTemplate, "":
return out, nil
case ContentFormatJSON, ContentFormatYAML, ContentFormatTOML:
var v any var v any
if err := json.Unmarshal(out, &v); err != nil { if err := parseStructured(out, &v); err != nil {
return nil, fmt.Errorf("%w: invalid json output: %w", ErrRenderFailed, err) return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err)
} }
case TemplateFormatYAML: result, err := marshalAs(v, format)
var v any if err != nil {
if err := yaml.Unmarshal(out, &v); err != nil { return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err)
return nil, fmt.Errorf("%w: invalid yaml output: %w", ErrRenderFailed, err)
}
case TemplateFormatTOML:
var v any
if err := toml.Unmarshal(out, &v); err != nil {
return nil, fmt.Errorf("%w: invalid toml output: %w", ErrRenderFailed, err)
} }
return result, nil
default:
return nil, fmt.Errorf("%w: unsupported format %q", ErrRenderFailed, format)
}
}
// parseStructured tries JSON, then YAML, then TOML and unmarshals into v.
func parseStructured(out []byte, v any) error {
if err := json.Unmarshal(out, v); err == nil {
return nil
}
if err := yaml.Unmarshal(out, v); err == nil {
return nil
}
if err := toml.Unmarshal(out, v); err == nil {
return nil
}
return fmt.Errorf("template output is not valid JSON, YAML, or TOML")
}
// marshalAs re-marshals v into the requested format.
func marshalAs(v any, format ContentFormat) ([]byte, error) {
switch format {
case ContentFormatJSON:
return json.MarshalIndent(v, "", " ")
case ContentFormatYAML:
return yaml.Marshal(v)
case ContentFormatTOML:
var buf bytes.Buffer
if err := toml.NewEncoder(&buf).Encode(v); err != nil {
return nil, err
}
return buf.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported format %q", format)
} }
return out, nil
} }
// buildRenderContext constructs the typed RenderContext from stored data. // buildRenderContext constructs the typed RenderContext from stored data.
+29 -14
View File
@@ -17,50 +17,65 @@ func TestRendererStructuredOutputValidation(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string
format bootstrap.TemplateFormat format bootstrap.ContentFormat
template string template string
err error err error
}{ }{
{ {
desc: "valid JSON output", desc: "valid JSON output",
format: bootstrap.TemplateFormatJSON, format: bootstrap.ContentFormatJSON,
template: `{"device_id":"{{ .Device.ID }}"}`, template: `{"device_id":"{{ .Device.ID }}"}`,
}, },
{ {
desc: "invalid JSON output", desc: "invalid output for JSON format",
format: bootstrap.TemplateFormatJSON, format: bootstrap.ContentFormatJSON,
template: `{"device_id":`, template: `[unclosed bracket`,
err: bootstrap.ErrRenderFailed, err: bootstrap.ErrRenderFailed,
}, },
{ {
desc: "valid YAML output", desc: "valid YAML output",
format: bootstrap.TemplateFormatYAML, format: bootstrap.ContentFormatYAML,
template: "device_id: {{ .Device.ID }}", template: "device_id: {{ .Device.ID }}",
}, },
{ {
desc: "invalid YAML output", desc: "invalid output for YAML format",
format: bootstrap.TemplateFormatYAML, format: bootstrap.ContentFormatYAML,
template: "device_id: [", template: "[unclosed bracket",
err: bootstrap.ErrRenderFailed, err: bootstrap.ErrRenderFailed,
}, },
{ {
desc: "valid TOML output", desc: "valid TOML output",
format: bootstrap.TemplateFormatTOML, format: bootstrap.ContentFormatTOML,
template: `device_id = "{{ .Device.ID }}"`, template: `device_id = "{{ .Device.ID }}"`,
}, },
{ {
desc: "invalid TOML output", desc: "invalid output for TOML format",
format: bootstrap.TemplateFormatTOML, format: bootstrap.ContentFormatTOML,
template: `device_id = `, template: `[unclosed bracket`,
err: bootstrap.ErrRenderFailed, err: bootstrap.ErrRenderFailed,
}, },
{
desc: "JSON template auto-converted to TOML",
format: bootstrap.ContentFormatTOML,
template: `{"device_id":"{{ .Device.ID }}"}`,
},
{
desc: "TOML template auto-converted to JSON",
format: bootstrap.ContentFormatJSON,
template: `device_id = "{{ .Device.ID }}"`,
},
{
desc: "YAML template auto-converted to TOML",
format: bootstrap.ContentFormatTOML,
template: "device_id: {{ .Device.ID }}",
},
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
_, err := renderer.Render( _, err := renderer.Render(
bootstrap.Profile{ bootstrap.Profile{
TemplateFormat: tc.format, ContentFormat: tc.format,
ContentTemplate: tc.template, ContentTemplate: tc.template,
}, },
bootstrap.Config{ID: "config-id"}, bootstrap.Config{ID: "config-id"},
+15 -17
View File
@@ -90,11 +90,11 @@ type Service interface {
// ViewProfile returns the Profile with the given ID. // ViewProfile returns the Profile with the given ID.
ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error)
// UpdateProfile updates editable fields of the given Profile. // UpdateProfile updates editable fields of the given Profile and returns the updated Profile.
UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) error UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error)
// ListProfiles returns a page of Profiles belonging to the domain. // ListProfiles returns a page of Profiles belonging to the domain.
ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (ProfilesPage, error) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error)
// DeleteProfile removes the Profile with the given ID. // DeleteProfile removes the Profile with the given ID.
DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error
@@ -328,8 +328,8 @@ func (bs bootstrapService) CreateProfile(ctx context.Context, session smqauthn.S
} }
p.ID = id p.ID = id
p.DomainID = session.DomainID p.DomainID = session.DomainID
if p.TemplateFormat == "" { if p.ContentFormat == "" {
p.TemplateFormat = TemplateFormatGoTemplate p.ContentFormat = ContentFormatJSON
} }
p.Version = 1 p.Version = 1
if err := validateProfileBindingSlots(p); err != nil { if err := validateProfileBindingSlots(p); err != nil {
@@ -356,31 +356,29 @@ func (bs bootstrapService) ViewProfile(ctx context.Context, session smqauthn.Ses
return p, nil return p, nil
} }
func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) error { func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) {
if bs.profiles == nil { if bs.profiles == nil {
return errors.Wrap(errUpdateProfile, errors.New("profile repository not configured")) return Profile{}, errors.Wrap(errUpdateProfile, errors.New("profile repository not configured"))
} }
p.DomainID = session.DomainID p.DomainID = session.DomainID
if p.TemplateFormat == "" {
p.TemplateFormat = TemplateFormatGoTemplate
}
if err := validateProfileBindingSlots(p); err != nil { if err := validateProfileBindingSlots(p); err != nil {
return errors.Wrap(errUpdateProfile, err) return Profile{}, errors.Wrap(errUpdateProfile, err)
} }
if err := validateProfileTemplate(p); err != nil { if err := validateProfileTemplate(p); err != nil {
return errors.Wrap(errUpdateProfile, err) return Profile{}, errors.Wrap(errUpdateProfile, err)
} }
if err := bs.profiles.Update(ctx, p); err != nil { updated, err := bs.profiles.Update(ctx, p)
return errors.Wrap(errUpdateProfile, err) if err != nil {
return Profile{}, errors.Wrap(errUpdateProfile, err)
} }
return nil return updated, nil
} }
func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (ProfilesPage, error) { func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error) {
if bs.profiles == nil { if bs.profiles == nil {
return ProfilesPage{}, errors.Wrap(errListProfiles, errors.New("profile repository not configured")) return ProfilesPage{}, errors.Wrap(errListProfiles, errors.New("profile repository not configured"))
} }
page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit) page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit, name)
if err != nil { if err != nil {
return ProfilesPage{}, errors.Wrap(errListProfiles, err) return ProfilesPage{}, errors.Wrap(errListProfiles, err)
} }
+44 -33
View File
@@ -735,7 +735,7 @@ func TestBootstrapRender(t *testing.T) {
ID: testsutil.GenerateUUID(&testing.T{}), ID: testsutil.GenerateUUID(&testing.T{}),
DomainID: domainID, DomainID: domainID,
Name: "gateway-profile", Name: "gateway-profile",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
ContentTemplate: `{"mode":"profile"}`, ContentTemplate: `{"mode":"profile"}`,
} }
bindings := []bootstrap.BindingSnapshot{ bindings := []bootstrap.BindingSnapshot{
@@ -968,11 +968,11 @@ func TestDisableConfig(t *testing.T) {
func TestAssignProfile(t *testing.T) { func TestAssignProfile(t *testing.T) {
profile := bootstrap.Profile{ profile := bootstrap.Profile{
ID: testsutil.GenerateUUID(t), ID: testsutil.GenerateUUID(t),
DomainID: domainID, DomainID: domainID,
Name: "gateway-profile", Name: "gateway-profile",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
Version: 1, Version: 1,
} }
cases := []struct { cases := []struct {
@@ -1034,23 +1034,26 @@ func TestCreateProfile(t *testing.T) {
session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}
validProfile := bootstrap.Profile{ validProfile := bootstrap.Profile{
Name: "test-profile", Name: "test-profile",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
} }
cases := []struct { cases := []struct {
desc string desc string
profile bootstrap.Profile profile bootstrap.Profile
saveErr error saveErr error
err error err error
wantFormat bootstrap.ContentFormat
}{ }{
{ {
desc: "create profile successfully", desc: "create profile successfully",
profile: validProfile, profile: validProfile,
wantFormat: bootstrap.ContentFormatGoTemplate,
}, },
{ {
desc: "create profile defaults to go-template format", desc: "create profile defaults to json format",
profile: bootstrap.Profile{Name: "no-format"}, profile: bootstrap.Profile{Name: "no-format"},
wantFormat: bootstrap.ContentFormatJSON,
}, },
{ {
desc: "create profile with invalid slot: empty name", desc: "create profile with invalid slot: empty name",
@@ -1107,7 +1110,7 @@ func TestCreateProfile(t *testing.T) {
if tc.err == nil { if tc.err == nil {
assert.NotEmpty(t, saved.ID, fmt.Sprintf("%s: expected non-empty profile ID\n", tc.desc)) assert.NotEmpty(t, saved.ID, fmt.Sprintf("%s: expected non-empty profile ID\n", tc.desc))
assert.Equal(t, domainID, saved.DomainID, fmt.Sprintf("%s: expected domain ID %s got %s\n", tc.desc, domainID, saved.DomainID)) assert.Equal(t, domainID, saved.DomainID, fmt.Sprintf("%s: expected domain ID %s got %s\n", tc.desc, domainID, saved.DomainID))
assert.Equal(t, bootstrap.TemplateFormatGoTemplate, saved.TemplateFormat, fmt.Sprintf("%s: expected go-template format\n", tc.desc)) assert.Equal(t, tc.wantFormat, saved.ContentFormat, fmt.Sprintf("%s: expected %s format\n", tc.desc, tc.wantFormat))
assert.Equal(t, 1, saved.Version, fmt.Sprintf("%s: expected version 1\n", tc.desc)) assert.Equal(t, 1, saved.Version, fmt.Sprintf("%s: expected version 1\n", tc.desc))
} }
saveCall.Unset() saveCall.Unset()
@@ -1119,11 +1122,11 @@ func TestViewProfile(t *testing.T) {
session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}
profile := bootstrap.Profile{ profile := bootstrap.Profile{
ID: testsutil.GenerateUUID(t), ID: testsutil.GenerateUUID(t),
DomainID: domainID, DomainID: domainID,
Name: "view-profile", Name: "view-profile",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
Version: 1, Version: 1,
} }
cases := []struct { cases := []struct {
@@ -1162,10 +1165,10 @@ func TestUpdateProfile(t *testing.T) {
session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}
validProfile := bootstrap.Profile{ validProfile := bootstrap.Profile{
ID: testsutil.GenerateUUID(t), ID: testsutil.GenerateUUID(t),
DomainID: domainID, DomainID: domainID,
Name: "updated-profile", Name: "updated-profile",
TemplateFormat: bootstrap.TemplateFormatGoTemplate, ContentFormat: bootstrap.ContentFormatGoTemplate,
} }
cases := []struct { cases := []struct {
@@ -1179,7 +1182,7 @@ func TestUpdateProfile(t *testing.T) {
profile: validProfile, profile: validProfile,
}, },
{ {
desc: "update profile defaults to go-template format", desc: "update profile with only name",
profile: bootstrap.Profile{ID: validProfile.ID, Name: "no-format"}, profile: bootstrap.Profile{ID: validProfile.ID, Name: "no-format"},
}, },
{ {
@@ -1223,8 +1226,8 @@ func TestUpdateProfile(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
svc := newService() svc := newService()
updateCall := profileRepo.On("Update", context.Background(), mock.Anything).Return(tc.updateErr) updateCall := profileRepo.On("Update", context.Background(), mock.Anything).Return(tc.profile, tc.updateErr)
err := svc.UpdateProfile(context.Background(), session, tc.profile) _, err := svc.UpdateProfile(context.Background(), session, tc.profile)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
updateCall.Unset() updateCall.Unset()
}) })
@@ -1235,15 +1238,17 @@ func TestListProfiles(t *testing.T) {
session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}
profiles := []bootstrap.Profile{ profiles := []bootstrap.Profile{
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", TemplateFormat: bootstrap.TemplateFormatGoTemplate, Version: 1}, {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", ContentFormat: bootstrap.ContentFormatGoTemplate, Version: 1},
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", TemplateFormat: bootstrap.TemplateFormatGoTemplate, Version: 1}, {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", ContentFormat: bootstrap.ContentFormatGoTemplate, Version: 1},
} }
page := bootstrap.ProfilesPage{Total: 2, Offset: 0, Limit: 10, Profiles: profiles} page := bootstrap.ProfilesPage{Total: 2, Offset: 0, Limit: 10, Profiles: profiles}
filteredPage := bootstrap.ProfilesPage{Total: 1, Offset: 0, Limit: 10, Profiles: profiles[:1]}
cases := []struct { cases := []struct {
desc string desc string
offset uint64 offset uint64
limit uint64 limit uint64
name string
page bootstrap.ProfilesPage page bootstrap.ProfilesPage
listErr error listErr error
err error err error
@@ -1253,6 +1258,12 @@ func TestListProfiles(t *testing.T) {
limit: 10, limit: 10,
page: page, page: page,
}, },
{
desc: "list profiles filtered by name",
limit: 10,
name: "p1",
page: filteredPage,
},
{ {
desc: "list profiles with repository error", desc: "list profiles with repository error",
limit: 10, limit: 10,
@@ -1264,8 +1275,8 @@ func TestListProfiles(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
svc := newService() svc := newService()
listCall := profileRepo.On("RetrieveAll", context.Background(), domainID, tc.offset, tc.limit).Return(tc.page, tc.listErr) listCall := profileRepo.On("RetrieveAll", context.Background(), domainID, tc.offset, tc.limit, tc.name).Return(tc.page, tc.listErr)
got, err := svc.ListProfiles(context.Background(), session, tc.offset, tc.limit) got, err := svc.ListProfiles(context.Background(), session, tc.offset, tc.limit, tc.name)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
if tc.err == nil { if tc.err == nil {
assert.Equal(t, tc.page, got, fmt.Sprintf("%s: expected page %v got %v\n", tc.desc, tc.page, got)) assert.Equal(t, tc.page, got, fmt.Sprintf("%s: expected page %v got %v\n", tc.desc, tc.page, got))
+3 -3
View File
@@ -139,7 +139,7 @@ func (tm *tracingMiddleware) ViewProfile(ctx context.Context, session smqauthn.S
return tm.svc.ViewProfile(ctx, session, profileID) return tm.svc.ViewProfile(ctx, session, profileID)
} }
func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
ctx, span := tm.tracer.Start(ctx, "svc_update_profile", trace.WithAttributes( ctx, span := tm.tracer.Start(ctx, "svc_update_profile", trace.WithAttributes(
attribute.String("profile_id", p.ID), attribute.String("profile_id", p.ID),
)) ))
@@ -147,13 +147,13 @@ func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn
return tm.svc.UpdateProfile(ctx, session, p) return tm.svc.UpdateProfile(ctx, session, p)
} }
func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_profiles", trace.WithAttributes( ctx, span := tm.tracer.Start(ctx, "svc_list_profiles", trace.WithAttributes(
attribute.Int64("offset", int64(offset)), attribute.Int64("offset", int64(offset)),
attribute.Int64("limit", int64(limit)), attribute.Int64("limit", int64(limit)),
)) ))
defer span.End() defer span.End()
return tm.svc.ListProfiles(ctx, session, offset, limit) return tm.svc.ListProfiles(ctx, session, offset, limit, name)
} }
func (tm *tracingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { func (tm *tracingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
+3 -2
View File
@@ -272,12 +272,13 @@ var cmdBootstrap = []cobra.Command{
return return
} }
if err := sdk.UpdateBootstrapProfile(cmd.Context(), profile, args[2], args[3]); err != nil { updated, err := sdk.UpdateBootstrapProfile(cmd.Context(), profile, args[2], args[3])
if err != nil {
logErrorCmd(*cmd, err) logErrorCmd(*cmd, err)
return return
} }
logOKCmd(*cmd) logJSONCmd(*cmd, updated)
case "remove": case "remove":
if len(args) != 4 { if len(args) != 4 {
logUsageCmd(*cmd, cmd.Use) logUsageCmd(*cmd, cmd.Use)
+4 -3
View File
@@ -35,7 +35,7 @@ var (
ID: profileID, ID: profileID,
Name: "Test Profile", Name: "Test Profile",
Description: "Test profile", Description: "Test profile",
TemplateFormat: "go-template", ContentFormat: "go-template",
ContentTemplate: "{\"device_id\":\"{{ .Device.ID }}\"}", ContentTemplate: "{\"device_id\":\"{{ .Device.ID }}\"}",
Version: 1, Version: 1,
} }
@@ -707,7 +707,8 @@ func TestBootstrapProfilesCmd(t *testing.T) {
domainID, domainID,
validToken, validToken,
}, },
logType: okLog, profile: bootProfile,
logType: entityLog,
}, },
{ {
desc: "remove bootstrap profile successfully", desc: "remove bootstrap profile successfully",
@@ -729,7 +730,7 @@ func TestBootstrapProfilesCmd(t *testing.T) {
createCall := sdkMock.On("CreateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr) createCall := sdkMock.On("CreateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr)
listCall := sdkMock.On("BootstrapProfiles", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr) listCall := sdkMock.On("BootstrapProfiles", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
viewCall := sdkMock.On("ViewBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr) viewCall := sdkMock.On("ViewBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr)
updateCall := sdkMock.On("UpdateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) updateCall := sdkMock.On("UpdateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr)
removeCall := sdkMock.On("RemoveBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) removeCall := sdkMock.On("RemoveBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{"profiles"}, tc.args...)...) out := executeCommand(t, rootCmd, append([]string{"profiles"}, tc.args...)...)
+15 -6
View File
@@ -109,7 +109,7 @@ type BootstrapProfile struct {
DomainID string `json:"domain_id,omitempty"` DomainID string `json:"domain_id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
TemplateFormat string `json:"template_format,omitempty"` ContentFormat string `json:"content_format,omitempty"`
ContentTemplate string `json:"content_template,omitempty"` ContentTemplate string `json:"content_template,omitempty"`
Defaults map[string]any `json:"defaults,omitempty"` Defaults map[string]any `json:"defaults,omitempty"`
BindingSlots []BindingSlot `json:"binding_slots,omitempty"` BindingSlots []BindingSlot `json:"binding_slots,omitempty"`
@@ -302,19 +302,28 @@ func (sdk mgSDK) UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domai
return sdkerr return sdkerr
} }
func (sdk mgSDK) UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) errors.SDKError { func (sdk mgSDK) UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) (BootstrapProfile, errors.SDKError) {
if profile.ID == "" { if profile.ID == "" {
return errors.NewSDKError(apiutil.ErrMissingID) return BootstrapProfile{}, errors.NewSDKError(apiutil.ErrMissingID)
} }
url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, bootstrapProfilesPath, profile.ID) url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, bootstrapProfilesPath, profile.ID)
data, err := json.Marshal(profile) data, err := json.Marshal(profile)
if err != nil { if err != nil {
return errors.NewSDKError(err) return BootstrapProfile{}, errors.NewSDKError(err)
} }
_, _, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK)
return sdkerr if sdkerr != nil {
return BootstrapProfile{}, sdkerr
}
var updated BootstrapProfile
if err := json.Unmarshal(body, &updated); err != nil {
return BootstrapProfile{}, errors.NewSDKError(err)
}
return updated, nil
} }
func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) {
+306
View File
@@ -908,6 +908,312 @@ func TestBootstrapSecure(t *testing.T) {
} }
} }
func TestCreateBootstrapProfile(t *testing.T) {
bs, bsvc, _, auth := setupBootstrap()
defer bs.Close()
mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL})
profile := sdk.BootstrapProfile{
Name: "gateway-profile",
ContentFormat: "go-template",
}
saved := bootstrap.Profile{
ID: testsutil.GenerateUUID(t),
DomainID: domainID,
Name: "gateway-profile",
ContentFormat: bootstrap.ContentFormatGoTemplate,
}
cases := []struct {
desc string
token string
profile sdk.BootstrapProfile
svcResp bootstrap.Profile
svcErr error
authErr error
expectedSDKErr errors.SDKError
}{
{
desc: "create profile successfully",
token: validToken,
profile: profile,
svcResp: saved,
},
{
desc: "create profile with invalid token",
token: invalidToken,
profile: profile,
authErr: svcerr.ErrAuthentication,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "create profile with empty name",
token: validToken,
profile: sdk.BootstrapProfile{},
expectedSDKErr: errors.NewSDKErrorWithStatus(apiutil.ErrMissingName, http.StatusBadRequest),
},
{
desc: "create profile with service error",
token: validToken,
profile: profile,
svcErr: svcerr.ErrCreateEntity,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
session := smqauthn.Session{}
if tc.token == validToken {
session = bootstrapSession()
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr)
svcCall := bsvc.On("CreateProfile", mock.Anything, session, mock.Anything).Return(tc.svcResp, tc.svcErr)
_, err := mgsdk.CreateBootstrapProfile(context.Background(), tc.profile, domainID, tc.token)
assert.Equal(t, tc.expectedSDKErr, err)
svcCall.Unset()
authCall.Unset()
})
}
}
func TestViewBootstrapProfile(t *testing.T) {
bs, bsvc, _, auth := setupBootstrap()
defer bs.Close()
mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL})
profileID := testsutil.GenerateUUID(t)
saved := bootstrap.Profile{
ID: profileID,
DomainID: domainID,
Name: "gateway-profile",
ContentFormat: bootstrap.ContentFormatGoTemplate,
}
expected := sdk.BootstrapProfile{
ID: profileID,
DomainID: domainID,
Name: "gateway-profile",
ContentFormat: "go-template",
}
cases := []struct {
desc string
token string
profileID string
svcResp bootstrap.Profile
svcErr error
authErr error
expectedResp sdk.BootstrapProfile
expectedSDKErr errors.SDKError
}{
{
desc: "view profile successfully",
token: validToken,
profileID: profileID,
svcResp: saved,
expectedResp: expected,
},
{
desc: "view profile with invalid token",
token: invalidToken,
profileID: profileID,
authErr: svcerr.ErrAuthentication,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "view profile with empty id",
token: validToken,
profileID: "",
expectedSDKErr: errors.NewSDKError(apiutil.ErrMissingID),
},
{
desc: "view profile not found",
token: validToken,
profileID: profileID,
svcErr: svcerr.ErrNotFound,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
session := smqauthn.Session{}
if tc.token == validToken {
session = bootstrapSession()
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr)
svcCall := bsvc.On("ViewProfile", mock.Anything, session, tc.profileID).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.ViewBootstrapProfile(context.Background(), tc.profileID, domainID, tc.token)
assert.Equal(t, tc.expectedSDKErr, err)
if tc.expectedSDKErr == nil {
assert.Equal(t, tc.expectedResp, resp)
}
svcCall.Unset()
authCall.Unset()
})
}
}
func TestUpdateBootstrapProfile(t *testing.T) {
bs, bsvc, _, auth := setupBootstrap()
defer bs.Close()
mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL})
profileID := testsutil.GenerateUUID(t)
updatedProfile := bootstrap.Profile{
ID: profileID,
DomainID: domainID,
Name: "updated-name",
ContentFormat: bootstrap.ContentFormatYAML,
}
expectedResp := sdk.BootstrapProfile{
ID: profileID,
DomainID: domainID,
Name: "updated-name",
ContentFormat: "yaml",
}
cases := []struct {
desc string
token string
profile sdk.BootstrapProfile
svcResp bootstrap.Profile
svcErr error
authErr error
expectedResp sdk.BootstrapProfile
expectedSDKErr errors.SDKError
}{
{
desc: "update profile successfully",
token: validToken,
profile: sdk.BootstrapProfile{
ID: profileID,
Name: "updated-name",
ContentFormat: "yaml",
},
svcResp: updatedProfile,
expectedResp: expectedResp,
},
{
desc: "update profile with invalid token",
token: invalidToken,
profile: sdk.BootstrapProfile{
ID: profileID,
},
authErr: svcerr.ErrAuthentication,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "update profile with empty id",
token: validToken,
profile: sdk.BootstrapProfile{},
expectedSDKErr: errors.NewSDKError(apiutil.ErrMissingID),
},
{
desc: "update profile not found",
token: validToken,
profile: sdk.BootstrapProfile{
ID: profileID,
},
svcErr: svcerr.ErrNotFound,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
session := smqauthn.Session{}
if tc.token == validToken {
session = bootstrapSession()
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr)
svcCall := bsvc.On("UpdateProfile", mock.Anything, session, mock.Anything).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.UpdateBootstrapProfile(context.Background(), tc.profile, domainID, tc.token)
assert.Equal(t, tc.expectedSDKErr, err)
if tc.expectedSDKErr == nil {
assert.Equal(t, tc.expectedResp, resp)
}
svcCall.Unset()
authCall.Unset()
})
}
}
func TestBootstrapProfiles(t *testing.T) {
bs, bsvc, _, auth := setupBootstrap()
defer bs.Close()
mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL})
profiles := bootstrap.ProfilesPage{
Total: 2,
Offset: 0,
Limit: 10,
Profiles: []bootstrap.Profile{
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", ContentFormat: bootstrap.ContentFormatGoTemplate},
{ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", ContentFormat: bootstrap.ContentFormatYAML},
},
}
cases := []struct {
desc string
token string
pageMeta sdk.PageMetadata
svcResp bootstrap.ProfilesPage
svcErr error
authErr error
expectedCount int
expectedSDKErr errors.SDKError
}{
{
desc: "list profiles successfully",
token: validToken,
pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10},
svcResp: profiles,
expectedCount: 2,
},
{
desc: "list profiles filtered by name",
token: validToken,
pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10, Name: "p1"},
svcResp: bootstrap.ProfilesPage{Total: 1, Profiles: profiles.Profiles[:1]},
expectedCount: 1,
},
{
desc: "list profiles with invalid token",
token: invalidToken,
pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10},
authErr: svcerr.ErrAuthentication,
expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
session := smqauthn.Session{}
if tc.token == validToken {
session = bootstrapSession()
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr)
svcCall := bsvc.On("ListProfiles", mock.Anything, session, mock.Anything, mock.Anything, mock.Anything).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.BootstrapProfiles(context.Background(), tc.pageMeta, domainID, tc.token)
assert.Equal(t, tc.expectedSDKErr, err)
if tc.expectedSDKErr == nil {
assert.Equal(t, tc.expectedCount, len(resp.Profiles))
}
svcCall.Unset()
authCall.Unset()
})
}
}
func encrypt(in, encKey []byte) ([]byte, error) { func encrypt(in, encKey []byte) ([]byte, error) {
block, err := aes.NewCipher(encKey) block, err := aes.NewCipher(encKey)
if err != nil { if err != nil {
+18 -9
View File
@@ -11793,22 +11793,31 @@ func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(ctx context.
} }
// UpdateBootstrapProfile provides a mock function for the type SDK // UpdateBootstrapProfile provides a mock function for the type SDK
func (_mock *SDK) UpdateBootstrapProfile(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) errors.SDKError { func (_mock *SDK) UpdateBootstrapProfile(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) (sdk.BootstrapProfile, errors.SDKError) {
ret := _mock.Called(ctx, profile, domainID, token) ret := _mock.Called(ctx, profile, domainID, token)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for UpdateBootstrapProfile") panic("no return value specified for UpdateBootstrapProfile")
} }
var r0 errors.SDKError var r0 sdk.BootstrapProfile
if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) errors.SDKError); ok { var r1 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) (sdk.BootstrapProfile, errors.SDKError)); ok {
return returnFunc(ctx, profile, domainID, token)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) sdk.BootstrapProfile); ok {
r0 = returnFunc(ctx, profile, domainID, token) r0 = returnFunc(ctx, profile, domainID, token)
} else { } else {
if ret.Get(0) != nil { r0 = ret.Get(0).(sdk.BootstrapProfile)
r0 = ret.Get(0).(errors.SDKError) }
if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.BootstrapProfile, string, string) errors.SDKError); ok {
r1 = returnFunc(ctx, profile, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
} }
} }
return r0 return r0, r1
} }
// SDK_UpdateBootstrapProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrapProfile' // SDK_UpdateBootstrapProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrapProfile'
@@ -11853,12 +11862,12 @@ func (_c *SDK_UpdateBootstrapProfile_Call) Run(run func(ctx context.Context, pro
return _c return _c
} }
func (_c *SDK_UpdateBootstrapProfile_Call) Return(sDKError errors.SDKError) *SDK_UpdateBootstrapProfile_Call { func (_c *SDK_UpdateBootstrapProfile_Call) Return(bootstrapProfile sdk.BootstrapProfile, sDKError errors.SDKError) *SDK_UpdateBootstrapProfile_Call {
_c.Call.Return(sDKError) _c.Call.Return(bootstrapProfile, sDKError)
return _c return _c
} }
func (_c *SDK_UpdateBootstrapProfile_Call) RunAndReturn(run func(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapProfile_Call { func (_c *SDK_UpdateBootstrapProfile_Call) RunAndReturn(run func(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) (sdk.BootstrapProfile, errors.SDKError)) *SDK_UpdateBootstrapProfile_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
+2 -2
View File
@@ -1631,8 +1631,8 @@ type SDK interface {
// UpdateBootstrap updates editable fields of the provided Config. // UpdateBootstrap updates editable fields of the provided Config.
UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) smqerrors.SDKError UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) smqerrors.SDKError
// UpdateBootstrapProfile updates editable fields of the provided bootstrap profile. // UpdateBootstrapProfile updates editable fields of the provided bootstrap profile and returns the updated profile.
UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) smqerrors.SDKError UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) (BootstrapProfile, smqerrors.SDKError)
// UpdateBootstrapCerts updates bootstrap config certificates. // UpdateBootstrapCerts updates bootstrap config certificates.
UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, smqerrors.SDKError) UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, smqerrors.SDKError)