mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
NOISSUE - Add Nullable type for optional values handling (#2877)
Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/channels"
|
||||
"github.com/absmach/supermq/internal/nullable"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
@@ -130,17 +131,11 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
var groupPtr *string
|
||||
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
|
||||
groupID, err := nullable.Parse(r.URL.Query(), api.GroupKey, nullable.ParseString)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
if r.URL.Query().Has(api.GroupKey) {
|
||||
groupPtr = &groupID
|
||||
}
|
||||
|
||||
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
@@ -165,7 +160,7 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
|
||||
Dir: dir,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Group: groupPtr,
|
||||
Group: groupID,
|
||||
Client: clientID,
|
||||
ID: id,
|
||||
},
|
||||
|
||||
+20
-19
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/internal/nullable"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
@@ -45,25 +46,25 @@ type Channel struct {
|
||||
}
|
||||
|
||||
type Page struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
Group *string `json:"group,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
ConnectionType string `json:"connection_type,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
IDs []string `json:"-"`
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
Group nullable.Value[string] `json:"group,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
ConnectionType string `json:"connection_type,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
IDs []string `json:"-"`
|
||||
}
|
||||
|
||||
// ChannelsPage contains page related metadata as well as list of channels that
|
||||
|
||||
@@ -1252,13 +1252,12 @@ func PageQuery(pm channels.Page) (string, error) {
|
||||
if pm.Domain != "" {
|
||||
query = append(query, "c.domain_id = :domain_id")
|
||||
}
|
||||
|
||||
if pm.Group != nil {
|
||||
switch *pm.Group {
|
||||
case "":
|
||||
query = append(query, "c.parent_group_id = '' ")
|
||||
default:
|
||||
if pm.Group.Set {
|
||||
switch {
|
||||
case pm.Group.Value != "":
|
||||
query = append(query, "c.parent_group_path <@ (SELECT path from groups where id = :group_id) ")
|
||||
default:
|
||||
query = append(query, "c.parent_group_id = '' ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1321,7 +1320,7 @@ func toDBChannelsPage(pm channels.Page) (dbChannelsPage, error) {
|
||||
Metadata: data,
|
||||
Tag: pm.Tag,
|
||||
Status: pm.Status,
|
||||
GroupID: pm.Group,
|
||||
GroupID: sql.NullString{Valid: pm.Group.Set, String: pm.Group.Value},
|
||||
ClientID: pm.Client,
|
||||
ConnType: pm.ConnectionType,
|
||||
RoleName: pm.RoleName,
|
||||
@@ -1340,7 +1339,7 @@ type dbChannelsPage struct {
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tag string `db:"tag"`
|
||||
Status channels.Status `db:"status"`
|
||||
GroupID *string `db:"group_id"`
|
||||
GroupID sql.NullString `db:"group_id"`
|
||||
ClientID string `db:"client_id"`
|
||||
ConnType string `db:"type"`
|
||||
RoleName string `db:"role_name"`
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nullable
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkPointerString(b *testing.B) {
|
||||
p := func() *string {
|
||||
x := "test"
|
||||
return &x
|
||||
}()
|
||||
for b.Loop() {
|
||||
if p != nil {
|
||||
_ = *p + "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableString(b *testing.B) {
|
||||
n := func() Value[string] {
|
||||
return Value[string]{Set: true, Value: "test"}
|
||||
}()
|
||||
|
||||
for b.Loop() {
|
||||
if n.Set {
|
||||
_ = n.Value + "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableStringParse(b *testing.B) {
|
||||
for b.Loop() {
|
||||
n, _ := ParseString("123")
|
||||
_ = n.Value
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerStringParse(b *testing.B) {
|
||||
parser := func(s string) (*string, error) {
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
n, _ := parser("123")
|
||||
_ = *n
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerInt(b *testing.B) {
|
||||
p := func() *int {
|
||||
x := 42
|
||||
return &x
|
||||
}()
|
||||
|
||||
for b.Loop() {
|
||||
if p != nil {
|
||||
_ = *p + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableInt(b *testing.B) {
|
||||
n := func() Value[int] {
|
||||
return Value[int]{Set: true, Value: 42}
|
||||
}()
|
||||
|
||||
for b.Loop() {
|
||||
if n.Set {
|
||||
_ = n.Value + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableIntParse(b *testing.B) {
|
||||
for b.Loop() {
|
||||
n, _ := ParseInt("123")
|
||||
_ = n.Value
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerIntParse(b *testing.B) {
|
||||
parser := func(s string) (*int, error) {
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
n, _ := parser("123")
|
||||
_ = *n
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerFloat(b *testing.B) {
|
||||
p := func() *float64 {
|
||||
x := float64(42)
|
||||
return &x
|
||||
}()
|
||||
|
||||
for b.Loop() {
|
||||
if p != nil {
|
||||
_ = *p + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableFloat(b *testing.B) {
|
||||
n := func() Value[float64] {
|
||||
return Value[float64]{Set: true, Value: 42}
|
||||
}()
|
||||
|
||||
for b.Loop() {
|
||||
if n.Set {
|
||||
_ = n.Value + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNullableFloatParse(b *testing.B) {
|
||||
for b.Loop() {
|
||||
n, _ := ParseFloat("123.45")
|
||||
_ = n.Value
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPointerFloatParse(b *testing.B) {
|
||||
parser := func(s string) (*float64, error) {
|
||||
v, err := strconv.ParseFloat(s, 10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
n, _ := parser("123.45")
|
||||
_ = *n
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseNullable(b *testing.B) {
|
||||
for b.Loop() {
|
||||
val, _ := Parse(url.Values{"key": []string{"123.456"}}, "key", ParseFloat)
|
||||
_ = val
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParsePointer(b *testing.B) {
|
||||
parser := func(q url.Values, key string) (*float64, error) {
|
||||
vals, ok := q[key]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
if len(vals) > 1 {
|
||||
return nil, ErrInvalidQueryParams
|
||||
}
|
||||
s := vals[0]
|
||||
if s == "" {
|
||||
return nil, nil // not nil, but empty
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
val, _ := parser(url.Values{"key": []string{"123.456"}}, "key")
|
||||
_ = val
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package nullable contains nullable types used to handle
|
||||
// scenarios where default values can't be used to indicate empty,
|
||||
// and we want to avoid using pointers for that.
|
||||
package nullable
|
||||
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nullable
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var ErrInvalidQueryParams = errors.New("invalid query parameters")
|
||||
|
||||
func Parse[T any](q url.Values, key string, parser FromString[T]) (Value[T], error) {
|
||||
vals, ok := q[key]
|
||||
if !ok {
|
||||
return Value[T]{}, nil
|
||||
}
|
||||
if len(vals) > 1 {
|
||||
return Value[T]{}, ErrInvalidQueryParams
|
||||
}
|
||||
s := vals[0]
|
||||
if s == "" {
|
||||
// The actual value is sent in query, so nullable is set, but empty.
|
||||
return Value[T]{Set: true}, nil
|
||||
}
|
||||
return parser(s)
|
||||
}
|
||||
|
||||
func ParseString(s string) (Value[string], error) {
|
||||
return Value[string]{Set: true, Value: s}, nil
|
||||
}
|
||||
|
||||
func ParseInt(s string) (Value[int], error) {
|
||||
val, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return Value[int]{}, err
|
||||
}
|
||||
return Value[int]{Set: true, Value: val}, nil
|
||||
}
|
||||
|
||||
func ParseFloat(s string) (Value[float64], error) {
|
||||
val, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return Value[float64]{}, err
|
||||
}
|
||||
return Value[float64]{Set: true, Value: val}, nil
|
||||
}
|
||||
|
||||
func ParseBool(s string) (Value[bool], error) {
|
||||
b, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return Value[bool]{}, err
|
||||
}
|
||||
return Value[bool]{Set: true, Value: b}, nil
|
||||
}
|
||||
|
||||
func ParseU16(s string) (Value[uint16], error) {
|
||||
val, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return Value[uint16]{}, err
|
||||
}
|
||||
return Value[uint16]{Set: true, Value: uint16(val)}, nil
|
||||
}
|
||||
|
||||
func ParseU64(s string) (Value[uint64], error) {
|
||||
val, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return Value[uint64]{}, err
|
||||
}
|
||||
return Value[uint64]{Set: true, Value: val}, nil
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package nullable
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseHelpers(t *testing.T) {
|
||||
t.Run("ParseString", func(t *testing.T) {
|
||||
val, err := ParseString("hello")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.Equal(t, "hello", val.Value)
|
||||
})
|
||||
|
||||
t.Run("ParseInt", func(t *testing.T) {
|
||||
val, err := ParseInt("42")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.Equal(t, 42, val.Value)
|
||||
|
||||
val, err = ParseInt("notanint")
|
||||
assert.Error(t, err)
|
||||
assert.False(t, val.Set)
|
||||
})
|
||||
|
||||
t.Run("ParseFloat", func(t *testing.T) {
|
||||
val, err := ParseFloat("3.14")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.Equal(t, 3.14, val.Value)
|
||||
})
|
||||
|
||||
t.Run("ParseBool", func(t *testing.T) {
|
||||
val, err := ParseBool("true")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.True(t, val.Value)
|
||||
|
||||
val, err = ParseBool("false")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.False(t, val.Value)
|
||||
|
||||
val, err = ParseBool("maybe")
|
||||
assert.Error(t, err)
|
||||
assert.False(t, val.Set)
|
||||
})
|
||||
|
||||
t.Run("ParseU16", func(t *testing.T) {
|
||||
val, err := ParseU16("65535")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.Equal(t, uint16(65535), val.Value)
|
||||
|
||||
val, err = ParseU16("70000")
|
||||
assert.Error(t, err)
|
||||
assert.False(t, val.Set)
|
||||
})
|
||||
|
||||
t.Run("ParseU64", func(t *testing.T) {
|
||||
val, err := ParseU64("1234567890")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.Set)
|
||||
assert.Equal(t, uint64(1234567890), val.Value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseQueryParam(t *testing.T) {
|
||||
type useCase struct {
|
||||
name string
|
||||
query url.Values
|
||||
key string
|
||||
parser func(string) (Value[int], error)
|
||||
expect Value[int]
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
cases := []useCase{
|
||||
{
|
||||
name: "missing key",
|
||||
query: url.Values{},
|
||||
key: "limit",
|
||||
parser: ParseInt,
|
||||
expect: Value[int]{Set: false},
|
||||
},
|
||||
{
|
||||
name: "empty value",
|
||||
query: url.Values{"limit": {""}},
|
||||
key: "limit",
|
||||
parser: ParseInt,
|
||||
expect: Value[int]{Set: true},
|
||||
},
|
||||
{
|
||||
name: "valid int",
|
||||
query: url.Values{"limit": {"10"}},
|
||||
key: "limit",
|
||||
parser: ParseInt,
|
||||
expect: Value[int]{Set: true, Value: 10},
|
||||
},
|
||||
{
|
||||
name: "invalid int",
|
||||
query: url.Values{"limit": {"bad"}},
|
||||
key: "limit",
|
||||
parser: ParseInt,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple values",
|
||||
query: url.Values{"limit": {"1", "2"}},
|
||||
key: "limit",
|
||||
parser: ParseInt,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, uc := range cases {
|
||||
t.Run(uc.name, func(t *testing.T) {
|
||||
val, err := Parse(uc.query, uc.key, uc.parser)
|
||||
if uc.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uc.expect, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nullable
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Value type is used to represent difference betweeen an
|
||||
// intentionally omitted value and default type value.
|
||||
type Value[T any] struct {
|
||||
Set bool
|
||||
Value T
|
||||
}
|
||||
|
||||
// FromString[T any] represents a parser function. It is used to avoid
|
||||
// a single parser for all nullables for improved readability and performance.
|
||||
// FromString should always return Nullable with Set=true, error otherwise.
|
||||
type FromString[T any] func(string) (Value[T], error)
|
||||
|
||||
// MarshalJSON encodes the value if set, otherwise returns `null`.
|
||||
func (n Value[T]) MarshalJSON() ([]byte, error) {
|
||||
if !n.Set {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(n.Value)
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes JSON and sets the value and Set flag.
|
||||
func (n *Value[T]) UnmarshalJSON(data []byte) error {
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
n.Set = false
|
||||
var empty T
|
||||
n.Value = empty
|
||||
return nil
|
||||
}
|
||||
|
||||
var val T
|
||||
if err := json.Unmarshal(data, &val); err != nil {
|
||||
return fmt.Errorf("nullable: failed to unmarshal: %w", err)
|
||||
}
|
||||
n.Value = val
|
||||
n.Set = true
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user