NOISSUE - Refactor error handling and SDK error testing (#263)

Improve error handling and SDK error testing. The code now includes signal handling, error wrapping, checking for contained errors, and testing HTTP error responses. Additionally, assertions have been added to ensure the expected behavior is met.

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>
This commit is contained in:
b1ackd0t
2024-01-08 13:33:16 +03:00
committed by GitHub
parent b3e206321c
commit 22ecbc24ee
5 changed files with 418 additions and 54 deletions
+16 -5
View File
@@ -12,6 +12,8 @@ import (
"net/http"
"net/url"
"os"
"os/signal"
"syscall"
"time"
"github.com/absmach/magistrala"
@@ -196,11 +198,7 @@ func main() {
})
g.Go(func() error {
if sig := errors.SignalHandler(ctx); sig != nil {
cancel()
logger.Info(fmt.Sprintf("mProxy shutdown by signal: %s", sig))
}
return nil
return stopSignalHandler(ctx, cancel, logger)
})
if err := g.Wait(); err != nil {
@@ -264,3 +262,16 @@ func healthcheck(cfg config) func() error {
return nil
}
}
func stopSignalHandler(ctx context.Context, cancel context.CancelFunc, logger mglog.Logger) error {
c := make(chan os.Signal, 2)
signal.Notify(c, syscall.SIGINT, syscall.SIGABRT)
select {
case sig := <-c:
defer cancel()
logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig))
return nil
case <-ctx.Done():
return nil
}
}
+8 -24
View File
@@ -4,12 +4,7 @@
package errors
import (
"context"
"encoding/json"
"fmt"
"os"
"os/signal"
"syscall"
)
// Error specifies an API that must be fullfiled by error type.
@@ -35,6 +30,14 @@ type customError struct {
err Error
}
// New returns an Error that formats as the given text.
func New(text string) Error {
return &customError{
msg: text,
err: nil,
}
}
func (ce *customError) Error() string {
if ce == nil {
return ""
@@ -123,22 +126,3 @@ func cast(err error) Error {
err: nil,
}
}
// New returns an Error that formats as the given text.
func New(text string) Error {
return &customError{
msg: text,
err: nil,
}
}
func SignalHandler(ctx context.Context) error {
c := make(chan os.Signal, 2)
signal.Notify(c, syscall.SIGINT, syscall.SIGABRT)
select {
case sig := <-c:
return fmt.Errorf("%s", sig)
case <-ctx.Done():
return nil
}
}
+180 -25
View File
@@ -4,6 +4,7 @@
package errors_test
import (
nerrors "errors"
"fmt"
"strconv"
"testing"
@@ -18,39 +19,63 @@ var (
err0 = errors.New("0")
err1 = errors.New("1")
err2 = errors.New("2")
nat = nerrors.New("native error")
)
func TestError(t *testing.T) {
cases := []struct {
desc string
err error
msg string
desc string
err error
msg string
bytes []byte
bytesErr error
}{
{
desc: "level 0 wrapped error",
err: err0,
msg: "0",
desc: "level 0 wrapped error",
err: err0,
msg: "0",
bytes: []byte(`{"error":"","message":"0"}`),
bytesErr: nil,
},
{
desc: "level 1 wrapped error",
err: wrap(1),
msg: message(1),
desc: "level 1 wrapped error",
err: wrap(1),
msg: message(1),
bytes: []byte(`{"error":"0","message":"1"}`),
bytesErr: nil,
},
{
desc: "level 2 wrapped error",
err: wrap(2),
msg: message(2),
desc: "level 2 wrapped error",
err: wrap(2),
msg: message(2),
bytes: []byte(`{"error":"1","message":"2"}`),
bytesErr: nil,
},
{
desc: fmt.Sprintf("level %d wrapped error", level),
err: wrap(level),
msg: message(level),
desc: fmt.Sprintf("level %d wrapped error", level),
err: wrap(level),
msg: message(level),
bytes: []byte(`{"error":"9","message":"` + strconv.Itoa(level) + `"}`),
bytesErr: nil,
},
{
desc: "nil error",
err: errors.New(""),
msg: "",
bytes: []byte(`{"error":"","message":""}`),
bytesErr: nil,
},
}
for _, tc := range cases {
errMsg := tc.err.Error()
assert.Equal(t, tc.msg, errMsg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.msg, errMsg))
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
errMsg := c.err.Error()
assert.Equal(t, c.msg, errMsg)
err := c.err.(errors.Error)
data, derr := err.MarshalJSON()
assert.Equal(t, c.bytesErr, derr)
assert.Equal(t, c.bytes, data)
})
}
}
@@ -115,10 +140,36 @@ func TestContains(t *testing.T) {
contained: wrap(level / 2),
contains: false,
},
{
desc: "native error contains error",
container: nat,
contained: err0,
contains: false,
},
{
desc: "res of errors.Wrap(err1, errors.New('')) contains err1",
container: errors.Wrap(err1, nat),
contained: err1,
contains: true,
},
{
desc: "error contains native error",
container: err0,
contained: nat,
contains: false,
},
{
desc: "res of errors.Wrap(errors.New(''), err0) contains err0",
container: errors.Wrap(nat, err0),
contained: err0,
contains: true,
},
}
for _, tc := range cases {
contains := errors.Contains(tc.container, tc.contained)
assert.Equal(t, tc.contains, contains, fmt.Sprintf("%s: expected %v to contain %v\n", tc.desc, tc.container, tc.contained))
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
contains := errors.Contains(c.container, c.contained)
assert.Equal(t, c.contains, contains)
})
}
}
@@ -172,12 +223,116 @@ func TestWrap(t *testing.T) {
contained: err0,
contains: false,
},
{
desc: "err0 wraps native error",
wrapper: err0,
wrapped: nat,
contained: nat,
contains: true,
},
{
desc: "nil wraps native error",
wrapper: nil,
wrapped: nat,
contained: nat,
contains: false,
},
{
desc: "native error wraps err0",
wrapper: nat,
wrapped: err0,
contained: err0,
contains: true,
},
{
desc: "native error wraps nil",
wrapper: nat,
wrapped: nil,
contained: nil,
contains: false,
},
{
desc: "err0 wraps err1 wraps native error",
wrapper: err0,
wrapped: errors.Wrap(err1, nat),
contained: nat,
contains: true,
},
{
desc: "native error wraps err1 wraps err0",
wrapper: nat,
wrapped: errors.Wrap(err1, err0),
contained: err0,
contains: true,
},
}
for _, tc := range cases {
err := errors.Wrap(tc.wrapper, tc.wrapped)
contains := errors.Contains(err, tc.contained)
assert.Equal(t, tc.contains, contains, fmt.Sprintf("%s: expected %v to contain %v\n", tc.desc, tc.wrapper, tc.wrapped))
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := errors.Wrap(c.wrapper, c.wrapped)
contains := errors.Contains(err, c.contained)
assert.Equal(t, c.contains, contains)
})
}
}
func TestUnwrap(t *testing.T) {
cases := []struct {
desc string
err error
wrapper error
wrapped error
}{
{
desc: "err 1 wraped err 2",
err: errors.Wrap(err1, err2),
wrapper: err1,
wrapped: err2,
},
{
desc: "err2 wraps err1 wraps err0",
err: errors.Wrap(err2, errors.Wrap(err1, err0)),
wrapper: err2,
wrapped: errors.Wrap(err1, err0),
},
{
desc: "nil wraps nil",
err: errors.Wrap(nil, nil),
wrapper: nil,
wrapped: nil,
},
{
desc: "err0 wraps nil",
err: errors.Wrap(err0, nil),
wrapper: nil,
wrapped: err0,
},
{
desc: "nil wraps err0",
err: errors.Wrap(nil, err0),
wrapper: nil,
wrapped: nil,
},
{
desc: "nil wraps native error",
err: errors.Wrap(nil, nat),
wrapper: nil,
wrapped: nil,
},
{
desc: "native error wraps nil",
err: errors.Wrap(nat, nil),
wrapper: nil,
wrapped: nat,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
wrapper, wrapped := errors.Unwrap(c.err)
assert.Equal(t, c.wrapper, wrapper)
assert.Equal(t, c.wrapped, wrapped)
})
}
}
+8
View File
@@ -71,6 +71,10 @@ func NewSDKError(err error) SDKError {
// NewSDKErrorWithStatus returns an SDK Error setting the status code.
func NewSDKErrorWithStatus(err error, statusCode int) SDKError {
if err == nil {
return nil
}
if e, ok := err.(Error); ok {
return &sdkError{
statusCode: statusCode,
@@ -93,6 +97,10 @@ func NewSDKErrorWithStatus(err error, statusCode int) SDKError {
// Since multiple status codes can be valid, we can pass multiple status codes to the function.
// The function then checks for errors in the HTTP response.
func CheckError(resp *http.Response, expectedStatusCodes ...int) SDKError {
if resp == nil {
return nil
}
for _, expectedStatusCode := range expectedStatusCodes {
if resp.StatusCode == expectedStatusCode {
return nil
+206
View File
@@ -0,0 +1,206 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package errors_test
import (
"bytes"
"fmt"
"io"
"net/http"
"testing"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
var body = []byte(`{"error":"error","message":"message"}`)
func TestNewSDKError(t *testing.T) {
cases := []struct {
desc string
err error
}{
{
desc: "nil error",
err: nil,
},
{
desc: "non nil error",
err: err0,
},
{
desc: "non nil error with wrapped error",
err: errors.Wrap(err0, err1),
},
{
desc: "native error",
err: nat,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.NewSDKError(c.err)
if c.err != nil {
assert.Equal(t, sdk.StatusCode(), 0)
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(0), c.err.Error()))
}
})
}
}
func TestNewSDKErrorWithStatus(t *testing.T) {
cases := []struct {
desc string
err error
sc int
}{
{
desc: "nil error with 0 status code",
err: nil,
sc: 0,
},
{
desc: "nil error with 404 status code",
err: nil,
sc: 404,
},
{
desc: "non nil error with 0 status code",
err: err0,
sc: 0,
},
{
desc: "non nil error with 404 status code",
err: err0,
sc: 404,
},
{
desc: "non nil error with wrapped error and 0 status code",
err: errors.Wrap(err0, err1),
sc: 0,
},
{
desc: "non nil error with wrapped error and 404 status code",
err: errors.Wrap(err0, err1),
sc: 404,
},
{
desc: "native error with 0 status code",
err: nat,
sc: 0,
},
{
desc: "native error with 404 status code",
err: nat,
sc: 404,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.NewSDKErrorWithStatus(c.err, c.sc)
if c.err != nil {
assert.Equal(t, sdk.StatusCode(), c.sc)
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(c.sc), c.err.Error()))
}
})
}
}
func TestCheckError(t *testing.T) {
cases := []struct {
desc string
resp *http.Response
codes []int
err errors.SDKError
}{
{
desc: "nil response",
resp: nil,
codes: []int{http.StatusOK},
err: nil,
},
{
desc: "nil response with 404 status code",
resp: nil,
codes: []int{http.StatusNotFound},
err: nil,
},
{
desc: "valid response with 200 status code",
resp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
},
codes: []int{http.StatusOK},
err: nil,
},
{
desc: "valid response with 404 status code",
resp: &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewReader(body)),
},
codes: []int{http.StatusNotFound},
err: nil,
},
{
desc: "invalid response with 200 status code",
resp: &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewReader(body)),
},
codes: []int{http.StatusOK},
err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.New("message"), errors.New("error")), http.StatusNotFound),
},
{
desc: "invalid response with 404 status code",
resp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
},
codes: []int{http.StatusNotFound},
err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.New("message"), errors.New("error")), http.StatusOK),
},
{
desc: "valid response with 200 status code and 404 status code",
resp: &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
},
codes: []int{http.StatusOK, http.StatusNotFound},
err: nil,
},
{
desc: "error in JSON marshalling",
resp: &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewReader([]byte(`"error":`))),
},
codes: []int{http.StatusOK},
err: errors.NewSDKErrorWithStatus(errors.New("invalid character ':' after top-level value"), http.StatusNotFound),
},
{
desc: "empty error message",
resp: &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewReader([]byte(`{"error":"","message":""}`))),
},
codes: []int{http.StatusOK},
err: errors.NewSDKErrorWithStatus(errors.New(""), http.StatusNotFound),
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.CheckError(c.resp, c.codes...)
assert.Equal(t, sdk, c.err)
if c.err != nil {
assert.Equal(t, sdk, c.err)
assert.Equal(t, sdk.StatusCode(), c.resp.StatusCode)
}
})
}
}