NOISSUE - Add Nullable type for optional values handling (#2877)

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2025-05-20 18:45:24 +02:00
committed by GitHub
parent d6c260b803
commit 7f4633a3d1
8 changed files with 469 additions and 35 deletions
+3 -8
View File
@@ -12,6 +12,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/internal/nullable"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
)
@@ -130,17 +131,11 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
var groupPtr *string
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
groupID, err := nullable.Parse(r.URL.Query(), api.GroupKey, nullable.ParseString)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
if r.URL.Query().Has(api.GroupKey) {
groupPtr = &groupID
}
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
@@ -165,7 +160,7 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
Dir: dir,
Offset: offset,
Limit: limit,
Group: groupPtr,
Group: groupID,
Client: clientID,
ID: id,
},
+20 -19
View File
@@ -7,6 +7,7 @@ import (
"context"
"time"
"github.com/absmach/supermq/internal/nullable"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/roles"
@@ -45,25 +46,25 @@ type Channel struct {
}
type Page struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status Status `json:"status,omitempty"`
Group *string `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status Status `json:"status,omitempty"`
Group nullable.Value[string] `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
}
// ChannelsPage contains page related metadata as well as list of channels that
+7 -8
View File
@@ -1252,13 +1252,12 @@ func PageQuery(pm channels.Page) (string, error) {
if pm.Domain != "" {
query = append(query, "c.domain_id = :domain_id")
}
if pm.Group != nil {
switch *pm.Group {
case "":
query = append(query, "c.parent_group_id = '' ")
default:
if pm.Group.Set {
switch {
case pm.Group.Value != "":
query = append(query, "c.parent_group_path <@ (SELECT path from groups where id = :group_id) ")
default:
query = append(query, "c.parent_group_id = '' ")
}
}
@@ -1321,7 +1320,7 @@ func toDBChannelsPage(pm channels.Page) (dbChannelsPage, error) {
Metadata: data,
Tag: pm.Tag,
Status: pm.Status,
GroupID: pm.Group,
GroupID: sql.NullString{Valid: pm.Group.Set, String: pm.Group.Value},
ClientID: pm.Client,
ConnType: pm.ConnectionType,
RoleName: pm.RoleName,
@@ -1340,7 +1339,7 @@ type dbChannelsPage struct {
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status channels.Status `db:"status"`
GroupID *string `db:"group_id"`
GroupID sql.NullString `db:"group_id"`
ClientID string `db:"client_id"`
ConnType string `db:"type"`
RoleName string `db:"role_name"`
+180
View File
@@ -0,0 +1,180 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nullable
import (
"net/url"
"strconv"
"testing"
)
func BenchmarkPointerString(b *testing.B) {
p := func() *string {
x := "test"
return &x
}()
for b.Loop() {
if p != nil {
_ = *p + "test"
}
}
}
func BenchmarkNullableString(b *testing.B) {
n := func() Value[string] {
return Value[string]{Set: true, Value: "test"}
}()
for b.Loop() {
if n.Set {
_ = n.Value + "test"
}
}
}
func BenchmarkNullableStringParse(b *testing.B) {
for b.Loop() {
n, _ := ParseString("123")
_ = n.Value
}
}
func BenchmarkPointerStringParse(b *testing.B) {
parser := func(s string) (*string, error) {
return &s, nil
}
for b.Loop() {
n, _ := parser("123")
_ = *n
}
}
func BenchmarkPointerInt(b *testing.B) {
p := func() *int {
x := 42
return &x
}()
for b.Loop() {
if p != nil {
_ = *p + 1
}
}
}
func BenchmarkNullableInt(b *testing.B) {
n := func() Value[int] {
return Value[int]{Set: true, Value: 42}
}()
for b.Loop() {
if n.Set {
_ = n.Value + 1
}
}
}
func BenchmarkNullableIntParse(b *testing.B) {
for b.Loop() {
n, _ := ParseInt("123")
_ = n.Value
}
}
func BenchmarkPointerIntParse(b *testing.B) {
parser := func(s string) (*int, error) {
v, err := strconv.Atoi(s)
if err != nil {
return nil, err
}
return &v, nil
}
for b.Loop() {
n, _ := parser("123")
_ = *n
}
}
func BenchmarkPointerFloat(b *testing.B) {
p := func() *float64 {
x := float64(42)
return &x
}()
for b.Loop() {
if p != nil {
_ = *p + 1
}
}
}
func BenchmarkNullableFloat(b *testing.B) {
n := func() Value[float64] {
return Value[float64]{Set: true, Value: 42}
}()
for b.Loop() {
if n.Set {
_ = n.Value + 1
}
}
}
func BenchmarkNullableFloatParse(b *testing.B) {
for b.Loop() {
n, _ := ParseFloat("123.45")
_ = n.Value
}
}
func BenchmarkPointerFloatParse(b *testing.B) {
parser := func(s string) (*float64, error) {
v, err := strconv.ParseFloat(s, 10)
if err != nil {
return nil, err
}
return &v, nil
}
for b.Loop() {
n, _ := parser("123.45")
_ = *n
}
}
func BenchmarkParseNullable(b *testing.B) {
for b.Loop() {
val, _ := Parse(url.Values{"key": []string{"123.456"}}, "key", ParseFloat)
_ = val
}
}
func BenchmarkParsePointer(b *testing.B) {
parser := func(q url.Values, key string) (*float64, error) {
vals, ok := q[key]
if !ok {
return nil, nil
}
if len(vals) > 1 {
return nil, ErrInvalidQueryParams
}
s := vals[0]
if s == "" {
return nil, nil // not nil, but empty
}
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, err
}
return &v, nil
}
for b.Loop() {
val, _ := parser(url.Values{"key": []string{"123.456"}}, "key")
_ = val
}
}
+7
View File
@@ -0,0 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package nullable contains nullable types used to handle
// scenarios where default values can't be used to indicate empty,
// and we want to avoid using pointers for that.
package nullable
+72
View File
@@ -0,0 +1,72 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nullable
import (
"errors"
"net/url"
"strconv"
)
var ErrInvalidQueryParams = errors.New("invalid query parameters")
func Parse[T any](q url.Values, key string, parser FromString[T]) (Value[T], error) {
vals, ok := q[key]
if !ok {
return Value[T]{}, nil
}
if len(vals) > 1 {
return Value[T]{}, ErrInvalidQueryParams
}
s := vals[0]
if s == "" {
// The actual value is sent in query, so nullable is set, but empty.
return Value[T]{Set: true}, nil
}
return parser(s)
}
func ParseString(s string) (Value[string], error) {
return Value[string]{Set: true, Value: s}, nil
}
func ParseInt(s string) (Value[int], error) {
val, err := strconv.Atoi(s)
if err != nil {
return Value[int]{}, err
}
return Value[int]{Set: true, Value: val}, nil
}
func ParseFloat(s string) (Value[float64], error) {
val, err := strconv.ParseFloat(s, 64)
if err != nil {
return Value[float64]{}, err
}
return Value[float64]{Set: true, Value: val}, nil
}
func ParseBool(s string) (Value[bool], error) {
b, err := strconv.ParseBool(s)
if err != nil {
return Value[bool]{}, err
}
return Value[bool]{Set: true, Value: b}, nil
}
func ParseU16(s string) (Value[uint16], error) {
val, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return Value[uint16]{}, err
}
return Value[uint16]{Set: true, Value: uint16(val)}, nil
}
func ParseU64(s string) (Value[uint64], error) {
val, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return Value[uint64]{}, err
}
return Value[uint64]{Set: true, Value: val}, nil
}
+132
View File
@@ -0,0 +1,132 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nullable
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseHelpers(t *testing.T) {
t.Run("ParseString", func(t *testing.T) {
val, err := ParseString("hello")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.Equal(t, "hello", val.Value)
})
t.Run("ParseInt", func(t *testing.T) {
val, err := ParseInt("42")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.Equal(t, 42, val.Value)
val, err = ParseInt("notanint")
assert.Error(t, err)
assert.False(t, val.Set)
})
t.Run("ParseFloat", func(t *testing.T) {
val, err := ParseFloat("3.14")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.Equal(t, 3.14, val.Value)
})
t.Run("ParseBool", func(t *testing.T) {
val, err := ParseBool("true")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.True(t, val.Value)
val, err = ParseBool("false")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.False(t, val.Value)
val, err = ParseBool("maybe")
assert.Error(t, err)
assert.False(t, val.Set)
})
t.Run("ParseU16", func(t *testing.T) {
val, err := ParseU16("65535")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.Equal(t, uint16(65535), val.Value)
val, err = ParseU16("70000")
assert.Error(t, err)
assert.False(t, val.Set)
})
t.Run("ParseU64", func(t *testing.T) {
val, err := ParseU64("1234567890")
assert.NoError(t, err)
assert.True(t, val.Set)
assert.Equal(t, uint64(1234567890), val.Value)
})
}
func TestParseQueryParam(t *testing.T) {
type useCase struct {
name string
query url.Values
key string
parser func(string) (Value[int], error)
expect Value[int]
expectErr bool
}
cases := []useCase{
{
name: "missing key",
query: url.Values{},
key: "limit",
parser: ParseInt,
expect: Value[int]{Set: false},
},
{
name: "empty value",
query: url.Values{"limit": {""}},
key: "limit",
parser: ParseInt,
expect: Value[int]{Set: true},
},
{
name: "valid int",
query: url.Values{"limit": {"10"}},
key: "limit",
parser: ParseInt,
expect: Value[int]{Set: true, Value: 10},
},
{
name: "invalid int",
query: url.Values{"limit": {"bad"}},
key: "limit",
parser: ParseInt,
expectErr: true,
},
{
name: "multiple values",
query: url.Values{"limit": {"1", "2"}},
key: "limit",
parser: ParseInt,
expectErr: true,
},
}
for _, uc := range cases {
t.Run(uc.name, func(t *testing.T) {
val, err := Parse(uc.query, uc.key, uc.parser)
if uc.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, uc.expect, val)
}
})
}
}
+48
View File
@@ -0,0 +1,48 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nullable
import (
"bytes"
"encoding/json"
"fmt"
)
// Value type is used to represent difference betweeen an
// intentionally omitted value and default type value.
type Value[T any] struct {
Set bool
Value T
}
// FromString[T any] represents a parser function. It is used to avoid
// a single parser for all nullables for improved readability and performance.
// FromString should always return Nullable with Set=true, error otherwise.
type FromString[T any] func(string) (Value[T], error)
// MarshalJSON encodes the value if set, otherwise returns `null`.
func (n Value[T]) MarshalJSON() ([]byte, error) {
if !n.Set {
return []byte("null"), nil
}
return json.Marshal(n.Value)
}
// UnmarshalJSON decodes JSON and sets the value and Set flag.
func (n *Value[T]) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, []byte("null")) {
n.Set = false
var empty T
n.Value = empty
return nil
}
var val T
if err := json.Unmarshal(data, &val); err != nil {
return fmt.Errorf("nullable: failed to unmarshal: %w", err)
}
n.Value = val
n.Set = true
return nil
}