From 683809dc6baf3e0cfc24e0909361e7692649c81d Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Tue, 19 May 2026 10:02:45 +0300 Subject: [PATCH] NOISSUE - Update bootstrap content format, update profile method and add profile search (#3515) Signed-off-by: nyagamunene --- bootstrap/api/endpoint.go | 23 +- bootstrap/api/endpoint_test.go | 258 ++++++++++++++++++---- bootstrap/api/requests.go | 2 + bootstrap/api/transport.go | 14 +- bootstrap/binding_validation.go | 2 +- bootstrap/events/producer/streams.go | 15 +- bootstrap/middleware/authorization.go | 8 +- bootstrap/middleware/logging.go | 6 +- bootstrap/middleware/metrics.go | 6 +- bootstrap/mocks/profile_repository.go | 55 +++-- bootstrap/mocks/service.go | 55 +++-- bootstrap/postgres/configs_test.go | 10 +- bootstrap/postgres/init.go | 9 + bootstrap/postgres/profiles.go | 142 ++++++++---- bootstrap/profiles.go | 24 +- bootstrap/renderer.go | 81 ++++--- bootstrap/renderer_test.go | 43 ++-- bootstrap/service.go | 32 ++- bootstrap/service_test.go | 77 ++++--- bootstrap/tracing/tracing.go | 6 +- cli/bootstrap.go | 5 +- cli/bootstrap_test.go | 7 +- pkg/sdk/bootstrap.go | 21 +- pkg/sdk/bootstrap_test.go | 306 ++++++++++++++++++++++++++ pkg/sdk/mocks/sdk.go | 27 ++- pkg/sdk/sdk.go | 4 +- 26 files changed, 952 insertions(+), 286 deletions(-) diff --git a/bootstrap/api/endpoint.go b/bootstrap/api/endpoint.go index 7f1f98630..95248640e 100644 --- a/bootstrap/api/endpoint.go +++ b/bootstrap/api/endpoint.go @@ -353,13 +353,27 @@ func renderPreviewEndpoint(svc bootstrap.Service) endpoint.Endpoint { } cfg := req.Config + bindings := req.Bindings + + if req.ConfigID != "" { + stored, err := svc.View(ctx, session, req.ConfigID) + if err != nil { + return nil, err + } + cfg = stored + bindings, err = svc.ListBindings(ctx, session, req.ConfigID) + if err != nil { + return nil, err + } + } + cfg.DomainID = session.DomainID cfg.ProfileID = p.ID if cfg.RenderContext == nil { cfg.RenderContext = req.RenderContext } - rendered, err := bootstrap.NewRenderer().Render(p, cfg, req.Bindings) + rendered, err := bootstrap.NewRenderer().Render(p, cfg, bindings) if err != nil { return nil, err } @@ -379,10 +393,11 @@ func updateProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint { return nil, svcerr.ErrAuthorization } req.Profile.ID = req.profileID - if err := svc.UpdateProfile(ctx, session, req.Profile); err != nil { + updated, err := svc.UpdateProfile(ctx, session, req.Profile) + if err != nil { return nil, err } - return profileRes{Profile: req.Profile}, nil + return profileRes{Profile: updated}, nil } } @@ -413,7 +428,7 @@ func listProfilesEndpoint(svc bootstrap.Service) endpoint.Endpoint { if !ok { return nil, svcerr.ErrAuthorization } - page, err := svc.ListProfiles(ctx, session, req.offset, req.limit) + page, err := svc.ListProfiles(ctx, session, req.offset, req.limit, req.name) if err != nil { return nil, err } diff --git a/bootstrap/api/endpoint_test.go b/bootstrap/api/endpoint_test.go index 6568ad68c..18a997a22 100644 --- a/bootstrap/api/endpoint_test.go +++ b/bootstrap/api/endpoint_test.go @@ -1182,7 +1182,7 @@ func TestUploadProfile(t *testing.T) { saved := bootstrap.Profile{ ID: testsutil.GenerateUUID(t), Name: "gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, ContentTemplate: "{{ .Device.ID }}", } @@ -1195,30 +1195,60 @@ func TestUploadProfile(t *testing.T) { { desc: "upload JSON profile", contentType: "application/json", - body: `{"name":"gateway","template_format":"go-template","content_template":"{{ .Device.ID }}"}`, + body: `{"name":"gateway","content_format":"go-template","content_template":"{{ .Device.ID }}"}`, profile: bootstrap.Profile{ Name: "gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, ContentTemplate: "{{ .Device.ID }}", }, }, { desc: "upload YAML profile", contentType: "application/yaml", - body: "name: gateway\ntemplate_format: go-template\ncontent_template: '{{ .Device.ID }}'\n", + body: "name: gateway\ncontent_format: go-template\ncontent_template: '{{ .Device.ID }}'\n", profile: bootstrap.Profile{ Name: "gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, ContentTemplate: "{{ .Device.ID }}", }, }, { desc: "upload TOML profile", contentType: "application/toml", - body: "name = 'gateway'\ntemplate_format = 'go-template'\ncontent_template = '{{ .Device.ID }}'\n", + body: "name = 'gateway'\ncontent_format = 'go-template'\ncontent_template = '{{ .Device.ID }}'\n", profile: bootstrap.Profile{ Name: "gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, + ContentTemplate: "{{ .Device.ID }}", + }, + }, + { + desc: "upload JSON profile without content_format infers json", + contentType: "application/json", + body: `{"name":"gateway","content_template":"{{ .Device.ID }}"}`, + profile: bootstrap.Profile{ + Name: "gateway", + ContentFormat: bootstrap.ContentFormatJSON, + ContentTemplate: "{{ .Device.ID }}", + }, + }, + { + desc: "upload YAML profile without content_format infers yaml", + contentType: "application/yaml", + body: "name: gateway\ncontent_template: '{{ .Device.ID }}'\n", + profile: bootstrap.Profile{ + Name: "gateway", + ContentFormat: bootstrap.ContentFormatYAML, + ContentTemplate: "{{ .Device.ID }}", + }, + }, + { + desc: "upload TOML profile without content_format infers toml", + contentType: "application/toml", + body: "name = 'gateway'\ncontent_template = '{{ .Device.ID }}'\n", + profile: bootstrap.Profile{ + Name: "gateway", + ContentFormat: bootstrap.ContentFormatTOML, ContentTemplate: "{{ .Device.ID }}", }, }, @@ -1246,6 +1276,83 @@ func TestUploadProfile(t *testing.T) { } } +func TestListProfiles(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + + session := smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + path := fmt.Sprintf("%s/%s/clients/bootstrap/profiles", bs.URL, domainID) + + profiles := []bootstrap.Profile{ + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "gateway-profile"}, + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "sensor-profile"}, + } + fullPage := bootstrap.ProfilesPage{Total: 2, Offset: 0, Limit: 10, Profiles: profiles} + filteredPage := bootstrap.ProfilesPage{Total: 1, Offset: 0, Limit: 10, Profiles: profiles[:1]} + + cases := []struct { + desc string + token string + session smqauthn.Session + url string + name string + svcPage bootstrap.ProfilesPage + svcErr error + authenticateErr error + status int + }{ + { + desc: "list profiles successfully", + token: validToken, + session: session, + url: fmt.Sprintf("%s?offset=0&limit=10", path), + svcPage: fullPage, + status: http.StatusOK, + }, + { + desc: "list profiles filtered by name", + token: validToken, + session: session, + url: fmt.Sprintf("%s?offset=0&limit=10&name=gateway-profile", path), + name: "gateway-profile", + svcPage: filteredPage, + status: http.StatusOK, + }, + { + desc: "list profiles with invalid token", + token: invalidToken, + url: fmt.Sprintf("%s?offset=0&limit=10", path), + authenticateErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "list profiles with limit exceeding max", + token: validToken, + session: session, + url: fmt.Sprintf("%s?offset=0&limit=101", path), + status: http.StatusBadRequest, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("ListProfiles", mock.Anything, tc.session, mock.Anything, mock.Anything, tc.name).Return(tc.svcPage, tc.svcErr) + req := testRequest{ + client: bs.Client(), + method: http.MethodGet, + url: tc.url, + token: tc.token, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status %d got %d", tc.desc, tc.status, res.StatusCode)) + authCall.Unset() + svcCall.Unset() + }) + } +} + func TestProfileSlots(t *testing.T) { bs, svc, auth := newBootstrapServer() defer bs.Close() @@ -1295,56 +1402,119 @@ func TestRenderPreview(t *testing.T) { profile := bootstrap.Profile{ ID: profileID, Name: "gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, ContentTemplate: `device={{ .Device.ID }} site={{ .Vars.site }} topic={{ index (index .Bindings "telemetry").Snapshot "topic" }}`, } - authCall := auth.On("Authenticate", mock.Anything, validToken).Return(session, nil) - svcCall := svc.On("ViewProfile", mock.Anything, session, profileID).Return(profile, nil) - reqBody := struct { + storedConfig := bootstrap.Config{ + ID: configID, + ExternalID: "gw-001", + DomainID: domainID, + RenderContext: map[string]any{ + "site": "warehouse-1", + }, + } + storedBindings := []bootstrap.BindingSnapshot{ + { + Slot: "telemetry", + Type: "channel", + ResourceID: "ch-1", + Snapshot: map[string]any{"topic": "devices/gw-001/telemetry"}, + }, + } + + inlineReqBody := struct { Config bootstrap.Config `json:"config"` Bindings []bootstrap.BindingSnapshot `json:"bindings"` }{ Config: bootstrap.Config{ - ID: configID, - ExternalID: "gw-001", - RenderContext: map[string]any{ - "site": "warehouse-1", - }, + ID: configID, + ExternalID: "gw-001", + RenderContext: map[string]any{"site": "warehouse-1"}, }, - Bindings: []bootstrap.BindingSnapshot{ - { - Slot: "telemetry", - Type: "channel", - ResourceID: "ch-1", - Snapshot: map[string]any{ - "topic": "devices/gw-001/telemetry", - }, - }, + Bindings: storedBindings, + } + + configIDReqBody := struct { + ConfigID string `json:"config_id"` + }{ + ConfigID: configID, + } + + expectedContent := "device=" + configID + " site=warehouse-1 topic=devices/gw-001/telemetry" + + cases := []struct { + desc string + body string + profileErr error + configErr error + bindingsErr error + status int + }{ + { + desc: "render preview with inline config and bindings", + body: toJSON(inlineReqBody), + status: http.StatusOK, + }, + { + desc: "render preview with config_id loads from db", + body: toJSON(configIDReqBody), + status: http.StatusOK, + }, + { + desc: "render preview with config_id and config not found", + body: toJSON(configIDReqBody), + configErr: svcerr.ErrNotFound, + status: http.StatusNotFound, + }, + { + desc: "render preview with config_id and bindings error", + body: toJSON(configIDReqBody), + bindingsErr: svcerr.ErrViewEntity, + status: http.StatusUnprocessableEntity, + }, + { + desc: "render preview with profile not found", + body: toJSON(inlineReqBody), + profileErr: svcerr.ErrNotFound, + status: http.StatusNotFound, }, } - req := testRequest{ - client: bs.Client(), - method: http.MethodPost, - url: fmt.Sprintf("%s/%s/clients/bootstrap/profiles/%s/render-preview", bs.URL, domainID, profileID), - contentType: contentType, - token: validToken, - body: strings.NewReader(toJSON(reqBody)), - } - res, err := req.make() - assert.Nil(t, err, fmt.Sprintf("render preview unexpected error %s", err)) - assert.Equal(t, http.StatusOK, res.StatusCode, fmt.Sprintf("expected status code %d got %d", http.StatusOK, res.StatusCode)) + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall := auth.On("Authenticate", mock.Anything, validToken).Return(session, nil) + svcCall := svc.On("ViewProfile", mock.Anything, session, profileID).Return(profile, tc.profileErr) + svcCall2 := svc.On("View", mock.Anything, session, configID).Return(storedConfig, tc.configErr) + svcCall3 := svc.On("ListBindings", mock.Anything, session, configID).Return(storedBindings, tc.bindingsErr) - var got struct { - Content string `json:"content"` - } - err = json.NewDecoder(res.Body).Decode(&got) - assert.Nil(t, err, fmt.Sprintf("decoding render preview expected to succeed: %s", err)) - assert.Equal(t, "device="+configID+" site=warehouse-1 topic=devices/gw-001/telemetry", got.Content) + req := testRequest{ + client: bs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/clients/bootstrap/profiles/%s/render-preview", bs.URL, domainID, profileID), + contentType: contentType, + token: validToken, + body: strings.NewReader(tc.body), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - svcCall.Unset() - authCall.Unset() + if tc.status == http.StatusOK { + var got struct { + Content string `json:"content"` + } + err = json.NewDecoder(res.Body).Decode(&got) + assert.Nil(t, err, fmt.Sprintf("%s: decoding expected to succeed: %s", tc.desc, err)) + assert.Equal(t, expectedContent, got.Content, fmt.Sprintf("%s: expected content %q got %q", tc.desc, expectedContent, got.Content)) + } + + svcCall3.Unset() + svcCall2.Unset() + svcCall.Unset() + authCall.Unset() + }) + } } type config struct { diff --git a/bootstrap/api/requests.go b/bootstrap/api/requests.go index 199692251..1863539ec 100644 --- a/bootstrap/api/requests.go +++ b/bootstrap/api/requests.go @@ -178,6 +178,7 @@ func (req updateProfileReq) validate() error { type renderPreviewReq struct { profileID string + ConfigID string `json:"config_id,omitempty"` Config bootstrap.Config `json:"config"` RenderContext map[string]any `json:"render_context,omitempty"` Bindings []bootstrap.BindingSnapshot `json:"bindings,omitempty"` @@ -204,6 +205,7 @@ func (req deleteProfileReq) validate() error { type listProfilesReq struct { offset uint64 limit uint64 + name string } func (req listProfilesReq) validate() error { diff --git a/bootstrap/api/transport.go b/bootstrap/api/transport.go index 8e3f685d8..9d6a1ed0d 100644 --- a/bootstrap/api/transport.go +++ b/bootstrap/api/transport.go @@ -368,13 +368,16 @@ func decodeCreateProfileRequest(_ context.Context, r *http.Request) (any, error) func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) { contentType := r.Header.Get("Content-Type") var req uploadProfileReq + var inferredFormat bootstrap.ContentFormat switch { case strings.Contains(contentType, "json"): + inferredFormat = bootstrap.ContentFormatJSON if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil { return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) } case strings.Contains(contentType, yamlContentType): + inferredFormat = bootstrap.ContentFormatYAML body, err := io.ReadAll(r.Body) if err != nil { return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) @@ -383,6 +386,7 @@ func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) } case strings.Contains(contentType, tomlContentType): + inferredFormat = bootstrap.ContentFormatTOML body, err := io.ReadAll(r.Body) if err != nil { return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) @@ -394,6 +398,10 @@ func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) return nil, apiutil.ErrUnsupportedContentType } + if req.Profile.ContentFormat == "" { + req.Profile.ContentFormat = inferredFormat + } + return req, nil } @@ -430,7 +438,11 @@ func decodeListProfilesRequest(_ context.Context, r *http.Request) (any, error) if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - return listProfilesReq{offset: o, limit: l}, nil + n, err := apiutil.ReadStringQuery(r, api.NameKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + return listProfilesReq{offset: o, limit: l, name: n}, nil } func decodeProfileEntityRequest(_ context.Context, r *http.Request) (any, error) { diff --git a/bootstrap/binding_validation.go b/bootstrap/binding_validation.go index d5baaeb80..34bcbe0fd 100644 --- a/bootstrap/binding_validation.go +++ b/bootstrap/binding_validation.go @@ -98,7 +98,7 @@ func mergeBindingSnapshots(existing, updated []BindingSnapshot) []BindingSnapsho } func validateProfileTemplate(p Profile) error { - if p.ContentTemplate == "" || p.TemplateFormat == TemplateFormatRaw { + if p.ContentTemplate == "" || p.ContentFormat == ContentFormatRaw { return nil } _, err := template.New("bootstrap").Funcs(allowlistedFuncs()).Parse(p.ContentTemplate) diff --git a/bootstrap/events/producer/streams.go b/bootstrap/events/producer/streams.go index 969cb8e5e..f024068f6 100644 --- a/bootstrap/events/producer/streams.go +++ b/bootstrap/events/producer/streams.go @@ -214,16 +214,17 @@ func (es *eventStore) ViewProfile(ctx context.Context, session smqauthn.Session, return p, nil } -func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { - if err := es.svc.UpdateProfile(ctx, session, p); err != nil { - return err +func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) { + updated, err := es.svc.UpdateProfile(ctx, session, p) + if err != nil { + return bootstrap.Profile{}, err } - ev := profileEvent{p, profileUpdate} - return es.Publish(ctx, updateProfileStream, ev) + ev := profileEvent{updated, profileUpdate} + return updated, es.Publish(ctx, updateProfileStream, ev) } -func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { - pp, err := es.svc.ListProfiles(ctx, session, offset, limit) +func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) { + pp, err := es.svc.ListProfiles(ctx, session, offset, limit, name) if err != nil { return pp, err } diff --git a/bootstrap/middleware/authorization.go b/bootstrap/middleware/authorization.go index d5fe16272..22cedf5a8 100644 --- a/bootstrap/middleware/authorization.go +++ b/bootstrap/middleware/authorization.go @@ -124,18 +124,18 @@ func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session smqa return am.svc.ViewProfile(ctx, session, profileID) } -func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { +func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) { if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, auth.AnyIDs); err != nil { - return err + return bootstrap.Profile{}, err } return am.svc.UpdateProfile(ctx, session, p) } -func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { +func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) { if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err != nil { return bootstrap.ProfilesPage{}, err } - return am.svc.ListProfiles(ctx, session, offset, limit) + return am.svc.ListProfiles(ctx, session, offset, limit, name) } func (am *authorizationMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { diff --git a/bootstrap/middleware/logging.go b/bootstrap/middleware/logging.go index 1e023adb2..5be17d747 100644 --- a/bootstrap/middleware/logging.go +++ b/bootstrap/middleware/logging.go @@ -233,7 +233,7 @@ func (lm *loggingMiddleware) ViewProfile(ctx context.Context, session smqauthn.S return lm.svc.ViewProfile(ctx, session, profileID) } -func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (err error) { +func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (updated bootstrap.Profile, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -250,7 +250,7 @@ func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn return lm.svc.UpdateProfile(ctx, session, p) } -func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (page bootstrap.ProfilesPage, err error) { +func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (page bootstrap.ProfilesPage, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -265,7 +265,7 @@ func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn. lm.logger.Info("List profiles completed successfully", args...) }(time.Now()) - return lm.svc.ListProfiles(ctx, session, offset, limit) + return lm.svc.ListProfiles(ctx, session, offset, limit, name) } func (lm *loggingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) (err error) { diff --git a/bootstrap/middleware/metrics.go b/bootstrap/middleware/metrics.go index 5dbe44a70..801b5eb1b 100644 --- a/bootstrap/middleware/metrics.go +++ b/bootstrap/middleware/metrics.go @@ -135,7 +135,7 @@ func (mm *metricsMiddleware) ViewProfile(ctx context.Context, session smqauthn.S return mm.svc.ViewProfile(ctx, session, profileID) } -func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { +func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) { defer func(begin time.Time) { mm.counter.With("method", "update_profile").Add(1) mm.latency.With("method", "update_profile").Observe(time.Since(begin).Seconds()) @@ -143,12 +143,12 @@ func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn return mm.svc.UpdateProfile(ctx, session, p) } -func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { +func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) { defer func(begin time.Time) { mm.counter.With("method", "list_profiles").Add(1) mm.latency.With("method", "list_profiles").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ListProfiles(ctx, session, offset, limit) + return mm.svc.ListProfiles(ctx, session, offset, limit, name) } func (mm *metricsMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { diff --git a/bootstrap/mocks/profile_repository.go b/bootstrap/mocks/profile_repository.go index b6b1013df..0fcec9c0b 100644 --- a/bootstrap/mocks/profile_repository.go +++ b/bootstrap/mocks/profile_repository.go @@ -106,8 +106,8 @@ func (_c *ProfileRepository_Delete_Call) RunAndReturn(run func(ctx context.Conte } // RetrieveAll provides a mock function for the type ProfileRepository -func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string, offset uint64, limit uint64) (bootstrap.ProfilesPage, error) { - ret := _mock.Called(ctx, domainID, offset, limit) +func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error) { + ret := _mock.Called(ctx, domainID, offset, limit, name) if len(ret) == 0 { panic("no return value specified for RetrieveAll") @@ -115,16 +115,16 @@ func (_mock *ProfileRepository) RetrieveAll(ctx context.Context, domainID string var r0 bootstrap.ProfilesPage var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (bootstrap.ProfilesPage, error)); ok { - return returnFunc(ctx, domainID, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64, string) (bootstrap.ProfilesPage, error)); ok { + return returnFunc(ctx, domainID, offset, limit, name) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) bootstrap.ProfilesPage); ok { - r0 = returnFunc(ctx, domainID, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64, string) bootstrap.ProfilesPage); ok { + r0 = returnFunc(ctx, domainID, offset, limit, name) } else { r0 = ret.Get(0).(bootstrap.ProfilesPage) } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { - r1 = returnFunc(ctx, domainID, offset, limit) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64, string) error); ok { + r1 = returnFunc(ctx, domainID, offset, limit, name) } else { r1 = ret.Error(1) } @@ -141,11 +141,12 @@ type ProfileRepository_RetrieveAll_Call struct { // - domainID string // - offset uint64 // - limit uint64 -func (_e *ProfileRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, offset interface{}, limit interface{}) *ProfileRepository_RetrieveAll_Call { - return &ProfileRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, offset, limit)} +// - name string +func (_e *ProfileRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, offset interface{}, limit interface{}, name interface{}) *ProfileRepository_RetrieveAll_Call { + return &ProfileRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, offset, limit, name)} } -func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, offset uint64, limit uint64)) *ProfileRepository_RetrieveAll_Call { +func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, offset uint64, limit uint64, name string)) *ProfileRepository_RetrieveAll_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -163,11 +164,16 @@ func (_c *ProfileRepository_RetrieveAll_Call) Run(run func(ctx context.Context, if args[3] != nil { arg3 = args[3].(uint64) } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } run( arg0, arg1, arg2, arg3, + arg4, ) }) return _c @@ -178,7 +184,7 @@ func (_c *ProfileRepository_RetrieveAll_Call) Return(profilesPage bootstrap.Prof return _c } -func (_c *ProfileRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, offset uint64, limit uint64) (bootstrap.ProfilesPage, error)) *ProfileRepository_RetrieveAll_Call { +func (_c *ProfileRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error)) *ProfileRepository_RetrieveAll_Call { _c.Call.Return(run) return _c } @@ -322,20 +328,29 @@ func (_c *ProfileRepository_Save_Call) RunAndReturn(run func(ctx context.Context } // Update provides a mock function for the type ProfileRepository -func (_mock *ProfileRepository) Update(ctx context.Context, p bootstrap.Profile) error { +func (_mock *ProfileRepository) Update(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) { ret := _mock.Called(ctx, p) if len(ret) == 0 { panic("no return value specified for Update") } - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) error); ok { + var r0 bootstrap.Profile + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) (bootstrap.Profile, error)); ok { + return returnFunc(ctx, p) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Profile) bootstrap.Profile); ok { r0 = returnFunc(ctx, p) } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(bootstrap.Profile) } - return r0 + if returnFunc, ok := ret.Get(1).(func(context.Context, bootstrap.Profile) error); ok { + r1 = returnFunc(ctx, p) + } else { + r1 = ret.Error(1) + } + return r0, r1 } // ProfileRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' @@ -368,12 +383,12 @@ func (_c *ProfileRepository_Update_Call) Run(run func(ctx context.Context, p boo return _c } -func (_c *ProfileRepository_Update_Call) Return(err error) *ProfileRepository_Update_Call { - _c.Call.Return(err) +func (_c *ProfileRepository_Update_Call) Return(profile bootstrap.Profile, err error) *ProfileRepository_Update_Call { + _c.Call.Return(profile, err) return _c } -func (_c *ProfileRepository_Update_Call) RunAndReturn(run func(ctx context.Context, p bootstrap.Profile) error) *ProfileRepository_Update_Call { +func (_c *ProfileRepository_Update_Call) RunAndReturn(run func(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error)) *ProfileRepository_Update_Call { _c.Call.Return(run) return _c } diff --git a/bootstrap/mocks/service.go b/bootstrap/mocks/service.go index 00d4c4b3f..43ebaa760 100644 --- a/bootstrap/mocks/service.go +++ b/bootstrap/mocks/service.go @@ -781,8 +781,8 @@ func (_c *Service_ListBindings_Call) RunAndReturn(run func(ctx context.Context, } // ListProfiles provides a mock function for the type Service -func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, offset uint64, limit uint64) (bootstrap.ProfilesPage, error) { - ret := _mock.Called(ctx, session, offset, limit) +func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error) { + ret := _mock.Called(ctx, session, offset, limit, name) if len(ret) == 0 { panic("no return value specified for ListProfiles") @@ -790,16 +790,16 @@ func (_mock *Service) ListProfiles(ctx context.Context, session authn.Session, o var r0 bootstrap.ProfilesPage var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64) (bootstrap.ProfilesPage, error)); ok { - return returnFunc(ctx, session, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64, string) (bootstrap.ProfilesPage, error)); ok { + return returnFunc(ctx, session, offset, limit, name) } - if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64) bootstrap.ProfilesPage); ok { - r0 = returnFunc(ctx, session, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, uint64, uint64, string) bootstrap.ProfilesPage); ok { + r0 = returnFunc(ctx, session, offset, limit, name) } else { r0 = ret.Get(0).(bootstrap.ProfilesPage) } - if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, uint64, uint64) error); ok { - r1 = returnFunc(ctx, session, offset, limit) + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, uint64, uint64, string) error); ok { + r1 = returnFunc(ctx, session, offset, limit, name) } else { r1 = ret.Error(1) } @@ -816,11 +816,12 @@ type Service_ListProfiles_Call struct { // - session authn.Session // - offset uint64 // - limit uint64 -func (_e *Service_Expecter) ListProfiles(ctx interface{}, session interface{}, offset interface{}, limit interface{}) *Service_ListProfiles_Call { - return &Service_ListProfiles_Call{Call: _e.mock.On("ListProfiles", ctx, session, offset, limit)} +// - name string +func (_e *Service_Expecter) ListProfiles(ctx interface{}, session interface{}, offset interface{}, limit interface{}, name interface{}) *Service_ListProfiles_Call { + return &Service_ListProfiles_Call{Call: _e.mock.On("ListProfiles", ctx, session, offset, limit, name)} } -func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64)) *Service_ListProfiles_Call { +func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string)) *Service_ListProfiles_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -838,11 +839,16 @@ func (_c *Service_ListProfiles_Call) Run(run func(ctx context.Context, session a if args[3] != nil { arg3 = args[3].(uint64) } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } run( arg0, arg1, arg2, arg3, + arg4, ) }) return _c @@ -853,7 +859,7 @@ func (_c *Service_ListProfiles_Call) Return(profilesPage bootstrap.ProfilesPage, return _c } -func (_c *Service_ListProfiles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64) (bootstrap.ProfilesPage, error)) *Service_ListProfiles_Call { +func (_c *Service_ListProfiles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, offset uint64, limit uint64, name string) (bootstrap.ProfilesPage, error)) *Service_ListProfiles_Call { _c.Call.Return(run) return _c } @@ -1144,20 +1150,29 @@ func (_c *Service_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, se } // UpdateProfile provides a mock function for the type Service -func (_mock *Service) UpdateProfile(ctx context.Context, session authn.Session, p bootstrap.Profile) error { +func (_mock *Service) UpdateProfile(ctx context.Context, session authn.Session, p bootstrap.Profile) (bootstrap.Profile, error) { ret := _mock.Called(ctx, session, p) if len(ret) == 0 { panic("no return value specified for UpdateProfile") } - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) error); ok { + var r0 bootstrap.Profile + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) (bootstrap.Profile, error)); ok { + return returnFunc(ctx, session, p) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Profile) bootstrap.Profile); ok { r0 = returnFunc(ctx, session, p) } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(bootstrap.Profile) } - return r0 + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, bootstrap.Profile) error); ok { + r1 = returnFunc(ctx, session, p) + } else { + r1 = ret.Error(1) + } + return r0, r1 } // Service_UpdateProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateProfile' @@ -1196,12 +1211,12 @@ func (_c *Service_UpdateProfile_Call) Run(run func(ctx context.Context, session return _c } -func (_c *Service_UpdateProfile_Call) Return(err error) *Service_UpdateProfile_Call { - _c.Call.Return(err) +func (_c *Service_UpdateProfile_Call) Return(profile bootstrap.Profile, err error) *Service_UpdateProfile_Call { + _c.Call.Return(profile, err) return _c } -func (_c *Service_UpdateProfile_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, p bootstrap.Profile) error) *Service_UpdateProfile_Call { +func (_c *Service_UpdateProfile_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, p bootstrap.Profile) (bootstrap.Profile, error)) *Service_UpdateProfile_Call { _c.Call.Return(run) return _c } diff --git a/bootstrap/postgres/configs_test.go b/bootstrap/postgres/configs_test.go index ff8019f26..064bcb95c 100644 --- a/bootstrap/postgres/configs_test.go +++ b/bootstrap/postgres/configs_test.go @@ -454,11 +454,11 @@ func TestAssignProfile(t *testing.T) { profileID := testsutil.GenerateUUID(t) _, err = profileRepo.Save(context.Background(), bootstrap.Profile{ - ID: profileID, - DomainID: c.DomainID, - Name: "edge-gateway", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, - Version: 1, + ID: profileID, + DomainID: c.DomainID, + Name: "edge-gateway", + ContentFormat: bootstrap.ContentFormatGoTemplate, + Version: 1, }) require.Nil(t, err, fmt.Sprintf("Saving profile expected to succeed: %s.\n", err)) diff --git a/bootstrap/postgres/init.go b/bootstrap/postgres/init.go index 47954c546..928e1dfc4 100644 --- a/bootstrap/postgres/init.go +++ b/bootstrap/postgres/init.go @@ -315,6 +315,15 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE IF EXISTS profiles DROP COLUMN IF EXISTS binding_slots`, }, }, + { + Id: "configs_16", + Up: []string{ + `ALTER TABLE IF EXISTS profiles RENAME COLUMN template_format TO content_format`, + }, + Down: []string{ + `ALTER TABLE IF EXISTS profiles RENAME COLUMN content_format TO template_format`, + }, + }, }, } } diff --git a/bootstrap/postgres/profiles.go b/bootstrap/postgres/profiles.go index 3c201bc42..68be8def5 100644 --- a/bootstrap/postgres/profiles.go +++ b/bootstrap/postgres/profiles.go @@ -9,14 +9,13 @@ import ( "encoding/json" "fmt" "log/slog" + "strings" "time" "github.com/absmach/magistrala/bootstrap" "github.com/absmach/magistrala/pkg/errors" repoerr "github.com/absmach/magistrala/pkg/errors/repository" "github.com/absmach/magistrala/pkg/postgres" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" ) var _ bootstrap.ProfileRepository = (*profileRepository)(nil) @@ -32,8 +31,8 @@ func NewProfileRepository(db postgres.Database, log *slog.Logger) bootstrap.Prof } func (pr profileRepository) Save(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) { - q := `INSERT INTO profiles (id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at) - VALUES (:id, :domain_id, :name, :description, :template_format, :content_template, :defaults, :binding_slots, :version, :created_at, :updated_at)` + q := `INSERT INTO profiles (id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at) + VALUES (:id, :domain_id, :name, :description, :content_format, :content_template, :defaults, :binding_slots, :version, :created_at, :updated_at)` now := time.Now().UTC() p.CreatedAt = now @@ -45,35 +44,42 @@ func (pr profileRepository) Save(ctx context.Context, p bootstrap.Profile) (boot } if _, err = pr.db.NamedExecContext(ctx, q, dbp); err != nil { - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == pgerrcode.UniqueViolation { - return bootstrap.Profile{}, repoerr.ErrConflict - } - return bootstrap.Profile{}, errors.Wrap(repoerr.ErrCreateEntity, err) + return bootstrap.Profile{}, postgres.HandleError(repoerr.ErrCreateEntity, err) } return p, nil } func (pr profileRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Profile, error) { - q := `SELECT id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at - FROM profiles WHERE id = $1 AND domain_id = $2` + q := `SELECT id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at + FROM profiles WHERE id = :id AND domain_id = :domain_id` + rows, err := pr.db.NamedQueryContext(ctx, q, dbProfile{ID: id, DomainID: domainID}) + if err != nil { + return bootstrap.Profile{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + if !rows.Next() { + return bootstrap.Profile{}, repoerr.ErrNotFound + } var dbp dbProfile - if err := pr.db.QueryRowxContext(ctx, q, id, domainID).StructScan(&dbp); err != nil { - if err == sql.ErrNoRows { - return bootstrap.Profile{}, repoerr.ErrNotFound - } + if err := rows.StructScan(&dbp); err != nil { return bootstrap.Profile{}, errors.Wrap(repoerr.ErrViewEntity, err) } return toProfile(dbp) } -func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, offset, limit uint64) (bootstrap.ProfilesPage, error) { - q := `SELECT id, domain_id, name, description, template_format, content_template, defaults, binding_slots, version, created_at, updated_at - FROM profiles WHERE domain_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3` +func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) { + dbPage := dbProfilesPage{DomainID: domainID, Offset: offset, Limit: limit, Name: name} + pageQuery := profilesPageQuery(dbPage) + q := fmt.Sprintf(`SELECT id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at + FROM profiles %s`, pageQuery) + q = applyProfilesOrdering(q) + q = fmt.Sprintf(`%s LIMIT :limit OFFSET :offset`, q) - rows, err := pr.db.QueryxContext(ctx, q, domainID, limit, offset) + rows, err := pr.db.NamedQueryContext(ctx, q, dbPage) if err != nil { return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } @@ -93,8 +99,9 @@ func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, of profiles = append(profiles, p) } - var total uint64 - if err := pr.db.QueryRowxContext(ctx, `SELECT COUNT(*) FROM profiles WHERE domain_id = $1`, domainID).Scan(&total); err != nil { + cq := fmt.Sprintf(`SELECT COUNT(*) FROM profiles %s`, pageQuery) + total, err := postgres.Total(ctx, pr.db, cq, dbPage) + if err != nil { return bootstrap.ProfilesPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } @@ -106,35 +113,82 @@ func (pr profileRepository) RetrieveAll(ctx context.Context, domainID string, of }, nil } -func (pr profileRepository) Update(ctx context.Context, p bootstrap.Profile) error { - q := `UPDATE profiles SET name = :name, description = :description, template_format = :template_format, - content_template = :content_template, defaults = :defaults, binding_slots = :binding_slots, version = version + 1, updated_at = :updated_at - WHERE id = :id AND domain_id = :domain_id` +type dbProfilesPage struct { + DomainID string `db:"domain_id"` + Offset uint64 `db:"offset"` + Limit uint64 `db:"limit"` + Name string `db:"name"` +} + +func profilesPageQuery(pm dbProfilesPage) string { + var query []string + query = append(query, "domain_id = :domain_id") + if pm.Name != "" { + query = append(query, "name ILIKE '%' || :name || '%'") + } + return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) +} + +func applyProfilesOrdering(q string) string { + return fmt.Sprintf("%s ORDER BY created_at DESC", q) +} + +func (pr profileRepository) Update(ctx context.Context, p bootstrap.Profile) (bootstrap.Profile, error) { + var query []string + var upq string + if p.Name != "" { + query = append(query, "name = :name,") + } + if p.Description != "" { + query = append(query, "description = :description,") + } + if p.ContentFormat != "" { + query = append(query, "content_format = :content_format,") + } + if p.ContentTemplate != "" { + query = append(query, "content_template = :content_template,") + } + if p.Defaults != nil { + query = append(query, "defaults = :defaults,") + } + if p.BindingSlots != nil { + query = append(query, "binding_slots = :binding_slots,") + } + if len(query) > 0 { + upq = strings.Join(query, " ") + } + + q := fmt.Sprintf(`UPDATE profiles SET %s version = version + 1, updated_at = :updated_at + WHERE id = :id AND domain_id = :domain_id + RETURNING id, domain_id, name, description, content_format, content_template, defaults, binding_slots, version, created_at, updated_at`, + upq) p.UpdatedAt = time.Now().UTC() dbp, err := toDBProfile(p) if err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) + return bootstrap.Profile{}, errors.Wrap(repoerr.ErrUpdateEntity, err) } - res, err := pr.db.NamedExecContext(ctx, q, dbp) + rows, err := pr.db.NamedQueryContext(ctx, q, dbp) if err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) + return bootstrap.Profile{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) } - cnt, err := res.RowsAffected() - if err != nil { - return errors.Wrap(repoerr.ErrUpdateEntity, err) + defer rows.Close() + + if !rows.Next() { + return bootstrap.Profile{}, repoerr.ErrNotFound } - if cnt == 0 { - return repoerr.ErrNotFound + var updated dbProfile + if err := rows.StructScan(&updated); err != nil { + return bootstrap.Profile{}, errors.Wrap(repoerr.ErrUpdateEntity, err) } - return nil + return toProfile(updated) } func (pr profileRepository) Delete(ctx context.Context, domainID, id string) error { - q := `DELETE FROM profiles WHERE id = $1 AND domain_id = $2` - if _, err := pr.db.ExecContext(ctx, q, id, domainID); err != nil { + q := `DELETE FROM profiles WHERE id = :id AND domain_id = :domain_id` + if _, err := pr.db.NamedExecContext(ctx, q, dbProfile{ID: id, DomainID: domainID}); err != nil { return errors.Wrap(repoerr.ErrRemoveEntity, err) } return nil @@ -146,7 +200,7 @@ type dbProfile struct { DomainID string `db:"domain_id"` Name string `db:"name"` Description sql.NullString `db:"description"` - TemplateFormat string `db:"template_format"` + ContentFormat string `db:"content_format"` ContentTemplate sql.NullString `db:"content_template"` Defaults []byte `db:"defaults"` BindingSlots []byte `db:"binding_slots"` @@ -169,7 +223,7 @@ func toDBProfile(p bootstrap.Profile) (dbProfile, error) { DomainID: p.DomainID, Name: p.Name, Description: nullString(p.Description), - TemplateFormat: string(p.TemplateFormat), + ContentFormat: string(p.ContentFormat), ContentTemplate: nullString(p.ContentTemplate), Defaults: defaults, BindingSlots: bindingSlots, @@ -181,13 +235,13 @@ func toDBProfile(p bootstrap.Profile) (dbProfile, error) { func toProfile(dbp dbProfile) (bootstrap.Profile, error) { p := bootstrap.Profile{ - ID: dbp.ID, - DomainID: dbp.DomainID, - Name: dbp.Name, - TemplateFormat: bootstrap.TemplateFormat(dbp.TemplateFormat), - Version: dbp.Version, - CreatedAt: dbp.CreatedAt, - UpdatedAt: dbp.UpdatedAt, + ID: dbp.ID, + DomainID: dbp.DomainID, + Name: dbp.Name, + ContentFormat: bootstrap.ContentFormat(dbp.ContentFormat), + Version: dbp.Version, + CreatedAt: dbp.CreatedAt, + UpdatedAt: dbp.UpdatedAt, } if dbp.Description.Valid { p.Description = dbp.Description.String diff --git a/bootstrap/profiles.go b/bootstrap/profiles.go index 58d9cca6f..c4b2356c8 100644 --- a/bootstrap/profiles.go +++ b/bootstrap/profiles.go @@ -8,15 +8,15 @@ import ( "time" ) -// TemplateFormat enumerates supported content template formats. -type TemplateFormat string +// ContentFormat enumerates the supported output formats for rendered profile templates. +type ContentFormat string const ( - TemplateFormatGoTemplate TemplateFormat = "go-template" - TemplateFormatRaw TemplateFormat = "raw" - TemplateFormatJSON TemplateFormat = "json" - TemplateFormatYAML TemplateFormat = "yaml" - TemplateFormatTOML TemplateFormat = "toml" + ContentFormatGoTemplate ContentFormat = "go-template" + ContentFormatRaw ContentFormat = "raw" + ContentFormatJSON ContentFormat = "json" + ContentFormatYAML ContentFormat = "yaml" + ContentFormatTOML ContentFormat = "toml" ) // Profile is a user-managed device configuration template. @@ -25,7 +25,7 @@ type Profile struct { DomainID string `json:"domain_id,omitempty"` Name string `json:"name"` Description string `json:"description,omitempty"` - TemplateFormat TemplateFormat `json:"template_format"` + ContentFormat ContentFormat `json:"content_format"` ContentTemplate string `json:"content_template,omitempty"` Defaults map[string]any `json:"defaults,omitempty"` BindingSlots []BindingSlot `json:"binding_slots,omitempty"` @@ -58,11 +58,11 @@ type ProfileRepository interface { // RetrieveByID returns the Profile with the given ID inside the given domain. RetrieveByID(ctx context.Context, domainID, id string) (Profile, error) - // RetrieveAll returns a page of Profiles belonging to the given domain. - RetrieveAll(ctx context.Context, domainID string, offset, limit uint64) (ProfilesPage, error) + // RetrieveAll returns a page of Profiles belonging to the given domain, optionally filtered by name. + RetrieveAll(ctx context.Context, domainID string, offset, limit uint64, name string) (ProfilesPage, error) - // Update updates editable fields of the given Profile. - Update(ctx context.Context, p Profile) error + // Update updates editable fields of the given Profile and returns the updated Profile. + Update(ctx context.Context, p Profile) (Profile, error) // Delete removes the Profile with the given ID from the given domain. Delete(ctx context.Context, domainID, id string) error diff --git a/bootstrap/renderer.go b/bootstrap/renderer.go index 886df5cc1..5689885bb 100644 --- a/bootstrap/renderer.go +++ b/bootstrap/renderer.go @@ -34,13 +34,13 @@ func NewRenderer() Renderer { func (r renderer) Render(profile Profile, enrollment Config, bindings []BindingSnapshot) ([]byte, error) { rctx := buildRenderContext(profile, enrollment, bindings) - switch profile.TemplateFormat { - case TemplateFormatRaw: + switch profile.ContentFormat { + case ContentFormatRaw: return []byte(profile.ContentTemplate), nil - case TemplateFormatGoTemplate, TemplateFormatJSON, TemplateFormatYAML, TemplateFormatTOML, "": + case ContentFormatGoTemplate, ContentFormatJSON, ContentFormatYAML, ContentFormatTOML, "": return r.renderTemplate(profile, rctx) default: - return nil, fmt.Errorf("%w: unsupported template format %q", ErrRenderFailed, profile.TemplateFormat) + return nil, fmt.Errorf("%w: unsupported template format %q", ErrRenderFailed, profile.ContentFormat) } } @@ -58,38 +58,61 @@ func (r renderer) renderTemplate(profile Profile, rctx RenderContext) ([]byte, e return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err) } - out, err := validateRenderedOutput(buf.Bytes(), profile.TemplateFormat) - if err != nil { - return nil, err - } - - return out, nil + return convertOutput(buf.Bytes(), profile.ContentFormat) } -// validateRenderedOutput checks that the rendered bytes are valid for the -// declared output format. It returns the original bytes on success and wraps -// ErrRenderFailed on failure. -func validateRenderedOutput(out []byte, format TemplateFormat) ([]byte, error) { - // Unrecognised formats are passed through. Recognised structured formats - // must parse successfully so broken templates fail before reaching devices. +// convertOutput parses the rendered bytes as any structured format (JSON, YAML, +// or TOML) and re-marshals them into the declared target format. For go-template +// or empty format the raw bytes are returned unchanged. +func convertOutput(out []byte, format ContentFormat) ([]byte, error) { switch format { - case TemplateFormatJSON: + case ContentFormatGoTemplate, "": + return out, nil + case ContentFormatJSON, ContentFormatYAML, ContentFormatTOML: var v any - if err := json.Unmarshal(out, &v); err != nil { - return nil, fmt.Errorf("%w: invalid json output: %w", ErrRenderFailed, err) + if err := parseStructured(out, &v); err != nil { + return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err) } - case TemplateFormatYAML: - var v any - if err := yaml.Unmarshal(out, &v); err != nil { - return nil, fmt.Errorf("%w: invalid yaml output: %w", ErrRenderFailed, err) - } - case TemplateFormatTOML: - var v any - if err := toml.Unmarshal(out, &v); err != nil { - return nil, fmt.Errorf("%w: invalid toml output: %w", ErrRenderFailed, err) + result, err := marshalAs(v, format) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRenderFailed, err) } + return result, nil + default: + return nil, fmt.Errorf("%w: unsupported format %q", ErrRenderFailed, format) + } +} + +// parseStructured tries JSON, then YAML, then TOML and unmarshals into v. +func parseStructured(out []byte, v any) error { + if err := json.Unmarshal(out, v); err == nil { + return nil + } + if err := yaml.Unmarshal(out, v); err == nil { + return nil + } + if err := toml.Unmarshal(out, v); err == nil { + return nil + } + return fmt.Errorf("template output is not valid JSON, YAML, or TOML") +} + +// marshalAs re-marshals v into the requested format. +func marshalAs(v any, format ContentFormat) ([]byte, error) { + switch format { + case ContentFormatJSON: + return json.MarshalIndent(v, "", " ") + case ContentFormatYAML: + return yaml.Marshal(v) + case ContentFormatTOML: + var buf bytes.Buffer + if err := toml.NewEncoder(&buf).Encode(v); err != nil { + return nil, err + } + return buf.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported format %q", format) } - return out, nil } // buildRenderContext constructs the typed RenderContext from stored data. diff --git a/bootstrap/renderer_test.go b/bootstrap/renderer_test.go index 8aea4a2ff..b8d61ed6a 100644 --- a/bootstrap/renderer_test.go +++ b/bootstrap/renderer_test.go @@ -17,50 +17,65 @@ func TestRendererStructuredOutputValidation(t *testing.T) { cases := []struct { desc string - format bootstrap.TemplateFormat + format bootstrap.ContentFormat template string err error }{ { desc: "valid JSON output", - format: bootstrap.TemplateFormatJSON, + format: bootstrap.ContentFormatJSON, template: `{"device_id":"{{ .Device.ID }}"}`, }, { - desc: "invalid JSON output", - format: bootstrap.TemplateFormatJSON, - template: `{"device_id":`, + desc: "invalid output for JSON format", + format: bootstrap.ContentFormatJSON, + template: `[unclosed bracket`, err: bootstrap.ErrRenderFailed, }, { desc: "valid YAML output", - format: bootstrap.TemplateFormatYAML, + format: bootstrap.ContentFormatYAML, template: "device_id: {{ .Device.ID }}", }, { - desc: "invalid YAML output", - format: bootstrap.TemplateFormatYAML, - template: "device_id: [", + desc: "invalid output for YAML format", + format: bootstrap.ContentFormatYAML, + template: "[unclosed bracket", err: bootstrap.ErrRenderFailed, }, { desc: "valid TOML output", - format: bootstrap.TemplateFormatTOML, + format: bootstrap.ContentFormatTOML, template: `device_id = "{{ .Device.ID }}"`, }, { - desc: "invalid TOML output", - format: bootstrap.TemplateFormatTOML, - template: `device_id = `, + desc: "invalid output for TOML format", + format: bootstrap.ContentFormatTOML, + template: `[unclosed bracket`, err: bootstrap.ErrRenderFailed, }, + { + desc: "JSON template auto-converted to TOML", + format: bootstrap.ContentFormatTOML, + template: `{"device_id":"{{ .Device.ID }}"}`, + }, + { + desc: "TOML template auto-converted to JSON", + format: bootstrap.ContentFormatJSON, + template: `device_id = "{{ .Device.ID }}"`, + }, + { + desc: "YAML template auto-converted to TOML", + format: bootstrap.ContentFormatTOML, + template: "device_id: {{ .Device.ID }}", + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { _, err := renderer.Render( bootstrap.Profile{ - TemplateFormat: tc.format, + ContentFormat: tc.format, ContentTemplate: tc.template, }, bootstrap.Config{ID: "config-id"}, diff --git a/bootstrap/service.go b/bootstrap/service.go index 77c806898..a66f2d164 100644 --- a/bootstrap/service.go +++ b/bootstrap/service.go @@ -90,11 +90,11 @@ type Service interface { // ViewProfile returns the Profile with the given ID. ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error) - // UpdateProfile updates editable fields of the given Profile. - UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) error + // UpdateProfile updates editable fields of the given Profile and returns the updated Profile. + UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) // ListProfiles returns a page of Profiles belonging to the domain. - ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (ProfilesPage, error) + ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error) // DeleteProfile removes the Profile with the given ID. DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error @@ -328,8 +328,8 @@ func (bs bootstrapService) CreateProfile(ctx context.Context, session smqauthn.S } p.ID = id p.DomainID = session.DomainID - if p.TemplateFormat == "" { - p.TemplateFormat = TemplateFormatGoTemplate + if p.ContentFormat == "" { + p.ContentFormat = ContentFormatJSON } p.Version = 1 if err := validateProfileBindingSlots(p); err != nil { @@ -356,31 +356,29 @@ func (bs bootstrapService) ViewProfile(ctx context.Context, session smqauthn.Ses return p, nil } -func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) error { +func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) { if bs.profiles == nil { - return errors.Wrap(errUpdateProfile, errors.New("profile repository not configured")) + return Profile{}, errors.Wrap(errUpdateProfile, errors.New("profile repository not configured")) } p.DomainID = session.DomainID - if p.TemplateFormat == "" { - p.TemplateFormat = TemplateFormatGoTemplate - } if err := validateProfileBindingSlots(p); err != nil { - return errors.Wrap(errUpdateProfile, err) + return Profile{}, errors.Wrap(errUpdateProfile, err) } if err := validateProfileTemplate(p); err != nil { - return errors.Wrap(errUpdateProfile, err) + return Profile{}, errors.Wrap(errUpdateProfile, err) } - if err := bs.profiles.Update(ctx, p); err != nil { - return errors.Wrap(errUpdateProfile, err) + updated, err := bs.profiles.Update(ctx, p) + if err != nil { + return Profile{}, errors.Wrap(errUpdateProfile, err) } - return nil + return updated, nil } -func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (ProfilesPage, error) { +func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error) { if bs.profiles == nil { return ProfilesPage{}, errors.Wrap(errListProfiles, errors.New("profile repository not configured")) } - page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit) + page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit, name) if err != nil { return ProfilesPage{}, errors.Wrap(errListProfiles, err) } diff --git a/bootstrap/service_test.go b/bootstrap/service_test.go index a4ab54414..d99986e7b 100644 --- a/bootstrap/service_test.go +++ b/bootstrap/service_test.go @@ -735,7 +735,7 @@ func TestBootstrapRender(t *testing.T) { ID: testsutil.GenerateUUID(&testing.T{}), DomainID: domainID, Name: "gateway-profile", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ContentFormat: bootstrap.ContentFormatGoTemplate, ContentTemplate: `{"mode":"profile"}`, } bindings := []bootstrap.BindingSnapshot{ @@ -968,11 +968,11 @@ func TestDisableConfig(t *testing.T) { func TestAssignProfile(t *testing.T) { profile := bootstrap.Profile{ - ID: testsutil.GenerateUUID(t), - DomainID: domainID, - Name: "gateway-profile", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, - Version: 1, + ID: testsutil.GenerateUUID(t), + DomainID: domainID, + Name: "gateway-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, + Version: 1, } cases := []struct { @@ -1034,23 +1034,26 @@ func TestCreateProfile(t *testing.T) { session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} validProfile := bootstrap.Profile{ - Name: "test-profile", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + Name: "test-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, } cases := []struct { - desc string - profile bootstrap.Profile - saveErr error - err error + desc string + profile bootstrap.Profile + saveErr error + err error + wantFormat bootstrap.ContentFormat }{ { - desc: "create profile successfully", - profile: validProfile, + desc: "create profile successfully", + profile: validProfile, + wantFormat: bootstrap.ContentFormatGoTemplate, }, { - desc: "create profile defaults to go-template format", - profile: bootstrap.Profile{Name: "no-format"}, + desc: "create profile defaults to json format", + profile: bootstrap.Profile{Name: "no-format"}, + wantFormat: bootstrap.ContentFormatJSON, }, { desc: "create profile with invalid slot: empty name", @@ -1107,7 +1110,7 @@ func TestCreateProfile(t *testing.T) { if tc.err == nil { assert.NotEmpty(t, saved.ID, fmt.Sprintf("%s: expected non-empty profile ID\n", tc.desc)) assert.Equal(t, domainID, saved.DomainID, fmt.Sprintf("%s: expected domain ID %s got %s\n", tc.desc, domainID, saved.DomainID)) - assert.Equal(t, bootstrap.TemplateFormatGoTemplate, saved.TemplateFormat, fmt.Sprintf("%s: expected go-template format\n", tc.desc)) + assert.Equal(t, tc.wantFormat, saved.ContentFormat, fmt.Sprintf("%s: expected %s format\n", tc.desc, tc.wantFormat)) assert.Equal(t, 1, saved.Version, fmt.Sprintf("%s: expected version 1\n", tc.desc)) } saveCall.Unset() @@ -1119,11 +1122,11 @@ func TestViewProfile(t *testing.T) { session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} profile := bootstrap.Profile{ - ID: testsutil.GenerateUUID(t), - DomainID: domainID, - Name: "view-profile", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, - Version: 1, + ID: testsutil.GenerateUUID(t), + DomainID: domainID, + Name: "view-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, + Version: 1, } cases := []struct { @@ -1162,10 +1165,10 @@ func TestUpdateProfile(t *testing.T) { session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} validProfile := bootstrap.Profile{ - ID: testsutil.GenerateUUID(t), - DomainID: domainID, - Name: "updated-profile", - TemplateFormat: bootstrap.TemplateFormatGoTemplate, + ID: testsutil.GenerateUUID(t), + DomainID: domainID, + Name: "updated-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, } cases := []struct { @@ -1179,7 +1182,7 @@ func TestUpdateProfile(t *testing.T) { profile: validProfile, }, { - desc: "update profile defaults to go-template format", + desc: "update profile with only name", profile: bootstrap.Profile{ID: validProfile.ID, Name: "no-format"}, }, { @@ -1223,8 +1226,8 @@ func TestUpdateProfile(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svc := newService() - updateCall := profileRepo.On("Update", context.Background(), mock.Anything).Return(tc.updateErr) - err := svc.UpdateProfile(context.Background(), session, tc.profile) + updateCall := profileRepo.On("Update", context.Background(), mock.Anything).Return(tc.profile, tc.updateErr) + _, err := svc.UpdateProfile(context.Background(), session, tc.profile) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) updateCall.Unset() }) @@ -1235,15 +1238,17 @@ func TestListProfiles(t *testing.T) { session := smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID} profiles := []bootstrap.Profile{ - {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", TemplateFormat: bootstrap.TemplateFormatGoTemplate, Version: 1}, - {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", TemplateFormat: bootstrap.TemplateFormatGoTemplate, Version: 1}, + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", ContentFormat: bootstrap.ContentFormatGoTemplate, Version: 1}, + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", ContentFormat: bootstrap.ContentFormatGoTemplate, Version: 1}, } page := bootstrap.ProfilesPage{Total: 2, Offset: 0, Limit: 10, Profiles: profiles} + filteredPage := bootstrap.ProfilesPage{Total: 1, Offset: 0, Limit: 10, Profiles: profiles[:1]} cases := []struct { desc string offset uint64 limit uint64 + name string page bootstrap.ProfilesPage listErr error err error @@ -1253,6 +1258,12 @@ func TestListProfiles(t *testing.T) { limit: 10, page: page, }, + { + desc: "list profiles filtered by name", + limit: 10, + name: "p1", + page: filteredPage, + }, { desc: "list profiles with repository error", limit: 10, @@ -1264,8 +1275,8 @@ func TestListProfiles(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svc := newService() - listCall := profileRepo.On("RetrieveAll", context.Background(), domainID, tc.offset, tc.limit).Return(tc.page, tc.listErr) - got, err := svc.ListProfiles(context.Background(), session, tc.offset, tc.limit) + listCall := profileRepo.On("RetrieveAll", context.Background(), domainID, tc.offset, tc.limit, tc.name).Return(tc.page, tc.listErr) + got, err := svc.ListProfiles(context.Background(), session, tc.offset, tc.limit, tc.name) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) if tc.err == nil { assert.Equal(t, tc.page, got, fmt.Sprintf("%s: expected page %v got %v\n", tc.desc, tc.page, got)) diff --git a/bootstrap/tracing/tracing.go b/bootstrap/tracing/tracing.go index c1008f60e..c26207542 100644 --- a/bootstrap/tracing/tracing.go +++ b/bootstrap/tracing/tracing.go @@ -139,7 +139,7 @@ func (tm *tracingMiddleware) ViewProfile(ctx context.Context, session smqauthn.S return tm.svc.ViewProfile(ctx, session, profileID) } -func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) error { +func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) { ctx, span := tm.tracer.Start(ctx, "svc_update_profile", trace.WithAttributes( attribute.String("profile_id", p.ID), )) @@ -147,13 +147,13 @@ func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn return tm.svc.UpdateProfile(ctx, session, p) } -func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64) (bootstrap.ProfilesPage, error) { +func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) { ctx, span := tm.tracer.Start(ctx, "svc_list_profiles", trace.WithAttributes( attribute.Int64("offset", int64(offset)), attribute.Int64("limit", int64(limit)), )) defer span.End() - return tm.svc.ListProfiles(ctx, session, offset, limit) + return tm.svc.ListProfiles(ctx, session, offset, limit, name) } func (tm *tracingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error { diff --git a/cli/bootstrap.go b/cli/bootstrap.go index cd3d37a0a..ad9a2a1aa 100644 --- a/cli/bootstrap.go +++ b/cli/bootstrap.go @@ -272,12 +272,13 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.UpdateBootstrapProfile(cmd.Context(), profile, args[2], args[3]); err != nil { + updated, err := sdk.UpdateBootstrapProfile(cmd.Context(), profile, args[2], args[3]) + if err != nil { logErrorCmd(*cmd, err) return } - logOKCmd(*cmd) + logJSONCmd(*cmd, updated) case "remove": if len(args) != 4 { logUsageCmd(*cmd, cmd.Use) diff --git a/cli/bootstrap_test.go b/cli/bootstrap_test.go index 1172ad2e7..57fbcee77 100644 --- a/cli/bootstrap_test.go +++ b/cli/bootstrap_test.go @@ -35,7 +35,7 @@ var ( ID: profileID, Name: "Test Profile", Description: "Test profile", - TemplateFormat: "go-template", + ContentFormat: "go-template", ContentTemplate: "{\"device_id\":\"{{ .Device.ID }}\"}", Version: 1, } @@ -707,7 +707,8 @@ func TestBootstrapProfilesCmd(t *testing.T) { domainID, validToken, }, - logType: okLog, + profile: bootProfile, + logType: entityLog, }, { desc: "remove bootstrap profile successfully", @@ -729,7 +730,7 @@ func TestBootstrapProfilesCmd(t *testing.T) { createCall := sdkMock.On("CreateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr) listCall := sdkMock.On("BootstrapProfiles", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr) viewCall := sdkMock.On("ViewBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr) - updateCall := sdkMock.On("UpdateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + updateCall := sdkMock.On("UpdateBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.profile, tc.sdkErr) removeCall := sdkMock.On("RemoveBootstrapProfile", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{"profiles"}, tc.args...)...) diff --git a/pkg/sdk/bootstrap.go b/pkg/sdk/bootstrap.go index 5e895125c..1a75729c6 100644 --- a/pkg/sdk/bootstrap.go +++ b/pkg/sdk/bootstrap.go @@ -109,7 +109,7 @@ type BootstrapProfile struct { DomainID string `json:"domain_id,omitempty"` Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` - TemplateFormat string `json:"template_format,omitempty"` + ContentFormat string `json:"content_format,omitempty"` ContentTemplate string `json:"content_template,omitempty"` Defaults map[string]any `json:"defaults,omitempty"` BindingSlots []BindingSlot `json:"binding_slots,omitempty"` @@ -302,19 +302,28 @@ func (sdk mgSDK) UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domai return sdkerr } -func (sdk mgSDK) UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) errors.SDKError { +func (sdk mgSDK) UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) (BootstrapProfile, errors.SDKError) { if profile.ID == "" { - return errors.NewSDKError(apiutil.ErrMissingID) + return BootstrapProfile{}, errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, bootstrapProfilesPath, profile.ID) data, err := json.Marshal(profile) if err != nil { - return errors.NewSDKError(err) + return BootstrapProfile{}, errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) - return sdkerr + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return BootstrapProfile{}, sdkerr + } + + var updated BootstrapProfile + if err := json.Unmarshal(body, &updated); err != nil { + return BootstrapProfile{}, errors.NewSDKError(err) + } + + return updated, nil } func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { diff --git a/pkg/sdk/bootstrap_test.go b/pkg/sdk/bootstrap_test.go index bbec3178e..bb85eca18 100644 --- a/pkg/sdk/bootstrap_test.go +++ b/pkg/sdk/bootstrap_test.go @@ -908,6 +908,312 @@ func TestBootstrapSecure(t *testing.T) { } } +func TestCreateBootstrapProfile(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL}) + + profile := sdk.BootstrapProfile{ + Name: "gateway-profile", + ContentFormat: "go-template", + } + saved := bootstrap.Profile{ + ID: testsutil.GenerateUUID(t), + DomainID: domainID, + Name: "gateway-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, + } + + cases := []struct { + desc string + token string + profile sdk.BootstrapProfile + svcResp bootstrap.Profile + svcErr error + authErr error + expectedSDKErr errors.SDKError + }{ + { + desc: "create profile successfully", + token: validToken, + profile: profile, + svcResp: saved, + }, + { + desc: "create profile with invalid token", + token: invalidToken, + profile: profile, + authErr: svcerr.ErrAuthentication, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "create profile with empty name", + token: validToken, + profile: sdk.BootstrapProfile{}, + expectedSDKErr: errors.NewSDKErrorWithStatus(apiutil.ErrMissingName, http.StatusBadRequest), + }, + { + desc: "create profile with service error", + token: validToken, + profile: profile, + svcErr: svcerr.ErrCreateEntity, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + session := smqauthn.Session{} + if tc.token == validToken { + session = bootstrapSession() + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr) + svcCall := bsvc.On("CreateProfile", mock.Anything, session, mock.Anything).Return(tc.svcResp, tc.svcErr) + + _, err := mgsdk.CreateBootstrapProfile(context.Background(), tc.profile, domainID, tc.token) + assert.Equal(t, tc.expectedSDKErr, err) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewBootstrapProfile(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL}) + + profileID := testsutil.GenerateUUID(t) + saved := bootstrap.Profile{ + ID: profileID, + DomainID: domainID, + Name: "gateway-profile", + ContentFormat: bootstrap.ContentFormatGoTemplate, + } + expected := sdk.BootstrapProfile{ + ID: profileID, + DomainID: domainID, + Name: "gateway-profile", + ContentFormat: "go-template", + } + + cases := []struct { + desc string + token string + profileID string + svcResp bootstrap.Profile + svcErr error + authErr error + expectedResp sdk.BootstrapProfile + expectedSDKErr errors.SDKError + }{ + { + desc: "view profile successfully", + token: validToken, + profileID: profileID, + svcResp: saved, + expectedResp: expected, + }, + { + desc: "view profile with invalid token", + token: invalidToken, + profileID: profileID, + authErr: svcerr.ErrAuthentication, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "view profile with empty id", + token: validToken, + profileID: "", + expectedSDKErr: errors.NewSDKError(apiutil.ErrMissingID), + }, + { + desc: "view profile not found", + token: validToken, + profileID: profileID, + svcErr: svcerr.ErrNotFound, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + session := smqauthn.Session{} + if tc.token == validToken { + session = bootstrapSession() + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr) + svcCall := bsvc.On("ViewProfile", mock.Anything, session, tc.profileID).Return(tc.svcResp, tc.svcErr) + + resp, err := mgsdk.ViewBootstrapProfile(context.Background(), tc.profileID, domainID, tc.token) + assert.Equal(t, tc.expectedSDKErr, err) + if tc.expectedSDKErr == nil { + assert.Equal(t, tc.expectedResp, resp) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateBootstrapProfile(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL}) + + profileID := testsutil.GenerateUUID(t) + updatedProfile := bootstrap.Profile{ + ID: profileID, + DomainID: domainID, + Name: "updated-name", + ContentFormat: bootstrap.ContentFormatYAML, + } + expectedResp := sdk.BootstrapProfile{ + ID: profileID, + DomainID: domainID, + Name: "updated-name", + ContentFormat: "yaml", + } + + cases := []struct { + desc string + token string + profile sdk.BootstrapProfile + svcResp bootstrap.Profile + svcErr error + authErr error + expectedResp sdk.BootstrapProfile + expectedSDKErr errors.SDKError + }{ + { + desc: "update profile successfully", + token: validToken, + profile: sdk.BootstrapProfile{ + ID: profileID, + Name: "updated-name", + ContentFormat: "yaml", + }, + svcResp: updatedProfile, + expectedResp: expectedResp, + }, + { + desc: "update profile with invalid token", + token: invalidToken, + profile: sdk.BootstrapProfile{ + ID: profileID, + }, + authErr: svcerr.ErrAuthentication, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "update profile with empty id", + token: validToken, + profile: sdk.BootstrapProfile{}, + expectedSDKErr: errors.NewSDKError(apiutil.ErrMissingID), + }, + { + desc: "update profile not found", + token: validToken, + profile: sdk.BootstrapProfile{ + ID: profileID, + }, + svcErr: svcerr.ErrNotFound, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + session := smqauthn.Session{} + if tc.token == validToken { + session = bootstrapSession() + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr) + svcCall := bsvc.On("UpdateProfile", mock.Anything, session, mock.Anything).Return(tc.svcResp, tc.svcErr) + + resp, err := mgsdk.UpdateBootstrapProfile(context.Background(), tc.profile, domainID, tc.token) + assert.Equal(t, tc.expectedSDKErr, err) + if tc.expectedSDKErr == nil { + assert.Equal(t, tc.expectedResp, resp) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestBootstrapProfiles(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + mgsdk := sdk.NewSDK(sdk.Config{BootstrapURL: bs.URL}) + + profiles := bootstrap.ProfilesPage{ + Total: 2, + Offset: 0, + Limit: 10, + Profiles: []bootstrap.Profile{ + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p1", ContentFormat: bootstrap.ContentFormatGoTemplate}, + {ID: testsutil.GenerateUUID(t), DomainID: domainID, Name: "p2", ContentFormat: bootstrap.ContentFormatYAML}, + }, + } + + cases := []struct { + desc string + token string + pageMeta sdk.PageMetadata + svcResp bootstrap.ProfilesPage + svcErr error + authErr error + expectedCount int + expectedSDKErr errors.SDKError + }{ + { + desc: "list profiles successfully", + token: validToken, + pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10}, + svcResp: profiles, + expectedCount: 2, + }, + { + desc: "list profiles filtered by name", + token: validToken, + pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10, Name: "p1"}, + svcResp: bootstrap.ProfilesPage{Total: 1, Profiles: profiles.Profiles[:1]}, + expectedCount: 1, + }, + { + desc: "list profiles with invalid token", + token: invalidToken, + pageMeta: sdk.PageMetadata{Offset: 0, Limit: 10}, + authErr: svcerr.ErrAuthentication, + expectedSDKErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + session := smqauthn.Session{} + if tc.token == validToken { + session = bootstrapSession() + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(session, tc.authErr) + svcCall := bsvc.On("ListProfiles", mock.Anything, session, mock.Anything, mock.Anything, mock.Anything).Return(tc.svcResp, tc.svcErr) + + resp, err := mgsdk.BootstrapProfiles(context.Background(), tc.pageMeta, domainID, tc.token) + assert.Equal(t, tc.expectedSDKErr, err) + if tc.expectedSDKErr == nil { + assert.Equal(t, tc.expectedCount, len(resp.Profiles)) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + func encrypt(in, encKey []byte) ([]byte, error) { block, err := aes.NewCipher(encKey) if err != nil { diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index fc253f3ad..382dcd25f 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -11793,22 +11793,31 @@ func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(ctx context. } // UpdateBootstrapProfile provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrapProfile(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) errors.SDKError { +func (_mock *SDK) UpdateBootstrapProfile(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) (sdk.BootstrapProfile, errors.SDKError) { ret := _mock.Called(ctx, profile, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrapProfile") } - var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) errors.SDKError); ok { + var r0 sdk.BootstrapProfile + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) (sdk.BootstrapProfile, errors.SDKError)); ok { + return returnFunc(ctx, profile, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapProfile, string, string) sdk.BootstrapProfile); ok { r0 = returnFunc(ctx, profile, domainID, token) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(errors.SDKError) + r0 = ret.Get(0).(sdk.BootstrapProfile) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.BootstrapProfile, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, profile, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) } } - return r0 + return r0, r1 } // SDK_UpdateBootstrapProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrapProfile' @@ -11853,12 +11862,12 @@ func (_c *SDK_UpdateBootstrapProfile_Call) Run(run func(ctx context.Context, pro return _c } -func (_c *SDK_UpdateBootstrapProfile_Call) Return(sDKError errors.SDKError) *SDK_UpdateBootstrapProfile_Call { - _c.Call.Return(sDKError) +func (_c *SDK_UpdateBootstrapProfile_Call) Return(bootstrapProfile sdk.BootstrapProfile, sDKError errors.SDKError) *SDK_UpdateBootstrapProfile_Call { + _c.Call.Return(bootstrapProfile, sDKError) return _c } -func (_c *SDK_UpdateBootstrapProfile_Call) RunAndReturn(run func(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapProfile_Call { +func (_c *SDK_UpdateBootstrapProfile_Call) RunAndReturn(run func(ctx context.Context, profile sdk.BootstrapProfile, domainID string, token string) (sdk.BootstrapProfile, errors.SDKError)) *SDK_UpdateBootstrapProfile_Call { _c.Call.Return(run) return _c } diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 9c638721f..544b50b1d 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -1631,8 +1631,8 @@ type SDK interface { // UpdateBootstrap updates editable fields of the provided Config. UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) smqerrors.SDKError - // UpdateBootstrapProfile updates editable fields of the provided bootstrap profile. - UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) smqerrors.SDKError + // UpdateBootstrapProfile updates editable fields of the provided bootstrap profile and returns the updated profile. + UpdateBootstrapProfile(ctx context.Context, profile BootstrapProfile, domainID, token string) (BootstrapProfile, smqerrors.SDKError) // UpdateBootstrapCerts updates bootstrap config certificates. UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, smqerrors.SDKError)