NOISSUE - Add Test for apiutil package (#195)

* Test for parsing and handling query parameters

The provided code includes test cases written that cover various scenarios for parsing and handling query parameters in HTTP requests.
These scenarios include valid and invalid queries, empty queries, and multiple queries with the same key.

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>

* replace ReadUintQuery with ReadNumQuery

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>

---------

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>
This commit is contained in:
b1ackd0t
2023-12-20 18:41:07 +03:00
committed by GitHub
parent cb3161ce0e
commit 0016d67055
8 changed files with 509 additions and 60 deletions
+2 -2
View File
@@ -182,12 +182,12 @@ func decodeUpdateConnRequest(_ context.Context, r *http.Request) (interface{}, e
}
func decodeListRequest(_ context.Context, r *http.Request) (interface{}, error) {
o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
+2 -2
View File
@@ -88,11 +88,11 @@ func encodeResponse(_ context.Context, w http.ResponseWriter, response interface
}
func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) {
l, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
+2 -2
View File
@@ -116,13 +116,13 @@ func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
req.contact = vals[0]
}
offset, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
offset, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return listSubsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req.offset = uint(offset)
limit, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
limit, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return listSubsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
+112
View File
@@ -0,0 +1,112 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil_test
import (
"net/http"
"testing"
"github.com/absmach/magistrala/internal/apiutil"
"github.com/stretchr/testify/assert"
)
func TestExtractBearerToken(t *testing.T) {
cases := []struct {
desc string
request *http.Request
token string
}{
{
desc: "valid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"Bearer 123"},
},
},
token: "123",
},
{
desc: "invalid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"123"},
},
},
token: "",
},
{
desc: "empty bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {""},
},
},
token: "",
},
{
desc: "empty header",
request: &http.Request{
Header: map[string][]string{},
},
token: "",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
token := apiutil.ExtractBearerToken(c.request)
assert.Equal(t, c.token, token)
})
}
}
func TestExtractThingKey(t *testing.T) {
cases := []struct {
desc string
request *http.Request
token string
}{
{
desc: "valid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"Thing 123"},
},
},
token: "123",
},
{
desc: "invalid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"123"},
},
},
token: "",
},
{
desc: "empty bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {""},
},
},
token: "",
},
{
desc: "empty header",
request: &http.Request{
Header: map[string][]string{},
},
token: "",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
token := apiutil.ExtractThingKey(c.request)
assert.Equal(t, c.token, token)
})
}
}
+17 -45
View File
@@ -24,26 +24,6 @@ func LoggingErrorEncoder(logger mglog.Logger, enc kithttp.ErrorEncoder) kithttp.
}
}
// ReadUintQuery reads the value of uint64 http query parameters for a given key.
func ReadUintQuery(r *http.Request, key string, def uint64) (uint64, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return 0, ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
strval := vals[0]
val, err := strconv.ParseUint(strval, 10, 64)
if err != nil {
return 0, ErrInvalidQueryParams
}
return val, nil
}
// ReadStringQuery reads the value of string http query parameters for a given key.
func ReadStringQuery(r *http.Request, key, def string) (string, error) {
vals := r.URL.Query()[key]
@@ -91,32 +71,12 @@ func ReadBoolQuery(r *http.Request, key string, def bool) (bool, error) {
b, err := strconv.ParseBool(vals[0])
if err != nil {
return false, ErrInvalidQueryParams
return false, errors.Wrap(ErrInvalidQueryParams, err)
}
return b, nil
}
// ReadFloatQuery reads the value of float64 http query parameters for a given key.
func ReadFloatQuery(r *http.Request, key string, def float64) (float64, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return 0, ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
fval := vals[0]
val, err := strconv.ParseFloat(fval, 64)
if err != nil {
return 0, ErrInvalidQueryParams
}
return val, nil
}
type number interface {
int64 | float64 | uint16 | uint64
}
@@ -135,16 +95,28 @@ func ReadNumQuery[N number](r *http.Request, key string, def N) (N, error) {
switch any(def).(type) {
case int64:
v, err := strconv.ParseInt(val, 10, 64)
return N(v), err
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case uint64:
v, err := strconv.ParseUint(val, 10, 64)
return N(v), err
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case uint16:
v, err := strconv.ParseUint(val, 10, 16)
return N(v), err
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case float64:
v, err := strconv.ParseFloat(val, 64)
return N(v), err
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
default:
return def, nil
}
+365
View File
@@ -0,0 +1,365 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/absmach/magistrala/internal/apiutil"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestReadStringQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret string
err error
}{
{
desc: "valid string query",
url: "http://localhost:8080/?key=test",
key: "key",
ret: "test",
err: nil,
},
{
desc: "empty string query",
url: "http://localhost:8080/",
key: "key",
ret: "",
err: nil,
},
{
desc: "multiple string query",
url: "http://localhost:8080/?key=test&key=random",
key: "key",
ret: "",
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadStringQuery(r, c.key, "")
assert.Equal(t, c.err, err)
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadMetadataQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret map[string]interface{}
err error
}{
{
desc: "valid metadata query",
url: "http://localhost:8080/?key={\"test\":\"test\"}",
key: "key",
ret: map[string]interface{}{"test": "test"},
err: nil,
},
{
desc: "empty metadata query",
url: "http://localhost:8080/",
key: "key",
ret: nil,
err: nil,
},
{
desc: "multiple metadata query",
url: "http://localhost:8080/?key={\"test\":\"test\"}&key={\"random\":\"random\"}",
key: "key",
ret: nil,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid metadata query",
url: "http://localhost:8080/?key=abc",
key: "key",
ret: nil,
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadMetadataQuery(r, c.key, nil)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadBoolQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret bool
err error
}{
{
desc: "valid bool query",
url: "http://localhost:8080/?key=true",
key: "key",
ret: true,
err: nil,
},
{
desc: "valid bool query",
url: "http://localhost:8080/?key=false",
key: "key",
ret: false,
err: nil,
},
{
desc: "invalid bool query",
url: "http://localhost:8080/?key=abc",
key: "key",
ret: false,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "empty bool query",
url: "http://localhost:8080/",
key: "key",
ret: false,
err: nil,
},
{
desc: "multiple bool query",
url: "http://localhost:8080/?key=true&key=false",
key: "key",
ret: false,
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadBoolQuery(r, c.key, false)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadNumQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
numType string
ret interface{}
err error
}{
{
desc: "valid int64 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "int64",
ret: int64(123),
err: nil,
},
{
desc: "valid float64 query",
url: "http://localhost:8080/?key=1.23",
key: "key",
numType: "float64",
ret: float64(1.23),
err: nil,
},
{
desc: "valid uint64 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "uint64",
ret: uint64(123),
err: nil,
},
{
desc: "valid uint16 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "uint16",
ret: uint16(123),
err: nil,
},
{
desc: "invalid int64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "int64",
ret: int64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid float64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "float64",
ret: float64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid uint64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "uint64",
ret: uint64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid uint16 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "uint16",
ret: uint16(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "empty int64 query",
url: "http://localhost:8080/",
key: "key",
numType: "int64",
ret: int64(0),
err: nil,
},
{
desc: "empty float64 query",
url: "http://localhost:8080/",
key: "key",
numType: "float64",
ret: float64(0),
err: nil,
},
{
desc: "empty uint16 query",
url: "http://localhost:8080/",
key: "key",
numType: "uint16",
ret: uint16(0),
err: nil,
},
{
desc: "empty uint64 query",
url: "http://localhost:8080/",
key: "key",
numType: "uint64",
ret: uint64(0),
err: nil,
},
{
desc: "multiple int64 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "int64",
ret: int64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple float64 query",
url: "http://localhost:8080/?key=1.23&key=4.56",
key: "key",
numType: "float64",
ret: float64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple uint16 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "uint16",
ret: uint16(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple uint64 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "uint64",
ret: uint64(0),
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
var ret interface{}
switch c.numType {
case "int64":
ret, err = apiutil.ReadNumQuery[int64](r, c.key, 0)
case "float64":
ret, err = apiutil.ReadNumQuery[float64](r, c.key, 0)
case "uint64":
ret, err = apiutil.ReadNumQuery[uint64](r, c.key, 0)
case "uint16":
ret, err = apiutil.ReadNumQuery[uint16](r, c.key, 0)
}
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestLoggingErrorEncoder(t *testing.T) {
logger := mglog.NewMock()
cases := []struct {
desc string
err error
}{
{
desc: "error contains ErrValidation",
err: errors.Wrap(apiutil.ErrValidation, errors.ErrAuthentication),
},
{
desc: "error does not contain ErrValidation",
err: errors.ErrAuthentication,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
encCalled := false
encFunc := func(ctx context.Context, err error, w http.ResponseWriter) {
encCalled = true
}
errorEncoder := apiutil.LoggingErrorEncoder(logger, encFunc)
errorEncoder(context.Background(), c.err, httptest.NewRecorder())
assert.True(t, encCalled)
})
}
}
+5 -5
View File
@@ -70,12 +70,12 @@ func MakeHandler(svc readers.MessageRepository, uauth magistrala.AuthServiceClie
}
func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
offset, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
offset, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
limit, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
@@ -105,7 +105,7 @@ func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
v, err := apiutil.ReadFloatQuery(r, valueKey, 0)
v, err := apiutil.ReadNumQuery[float64](r, valueKey, 0)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
@@ -130,12 +130,12 @@ func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
return nil, err
}
from, err := apiutil.ReadFloatQuery(r, fromKey, 0)
from, err := apiutil.ReadNumQuery[float64](r, fromKey, 0)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
to, err := apiutil.ReadFloatQuery(r, toKey, 0)
to, err := apiutil.ReadNumQuery[float64](r, toKey, 0)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
+4 -4
View File
@@ -123,12 +123,12 @@ func decodeView(_ context.Context, r *http.Request) (interface{}, error) {
}
func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
l, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
@@ -155,12 +155,12 @@ func decodeList(_ context.Context, r *http.Request) (interface{}, error) {
}
func decodeListStates(_ context.Context, r *http.Request) (interface{}, error) {
l, err := apiutil.ReadUintQuery(r, limitKey, defLimit)
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadUintQuery(r, offsetKey, defOffset)
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}