mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 07:00:25 +00:00
NOISSUE - Add Tests For AES Module on Rules Engine (#182)
* refactor(re): aes error handling and adding unit tests Added unit tests for lua bindings * fix(re): improve error handling return 2 values in case of error and 1 value for valid data * style: add license information
This commit is contained in:
@@ -6,55 +6,59 @@ package re
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidDataSize = errors.New("data is not a multiple of the block size")
|
||||
errInvalidIVSize = errors.New("size of the IV is not the same as block size")
|
||||
)
|
||||
|
||||
// AES CBC-128 DECRYPTION requires 3 data fields
|
||||
// encrypt implements AES CBC-128 ENCRYPTION which requires 3 data fields
|
||||
// 1. Key (16 bytes)
|
||||
// 2. Initialization Vector (IV) (16 bytes)
|
||||
// 3. Encrypted Data (16 bytes or length multiple a of 16)
|
||||
// The encrypted data is divided into blocks of 16 bytes (128 bits) which then operated on with the IV and Key.
|
||||
func encrypt(key []byte, iv []byte, data []byte) ([]byte, error) {
|
||||
if len(iv) != aes.BlockSize {
|
||||
return nil, errInvalidIVSize
|
||||
}
|
||||
if len(data)%aes.BlockSize != 0 {
|
||||
return nil, errInvalidDataSize
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypted := make([]byte, len(data))
|
||||
blockSize := block.BlockSize()
|
||||
if len(data)%blockSize != 0 {
|
||||
return nil, fmt.Errorf("payload length %d is not a multiple of AES block size %d", len(data), blockSize)
|
||||
}
|
||||
|
||||
if len(iv) != blockSize {
|
||||
return nil, fmt.Errorf("size of the IV %d is not the same as block size %d", len(iv), blockSize)
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
encrypted := make([]byte, len(data))
|
||||
mode.CryptBlocks(encrypted, data)
|
||||
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// decrypt implements AES CBC-128 DECRYPTION which requires 3 data fields
|
||||
// 1. Key (16 bytes)
|
||||
// 2. Initialization Vector (IV) (16 bytes)
|
||||
// 3. Encrypted Data (16 bytes or length multiple a of 16)
|
||||
// The encrypted data is divided into blocks of 16 bytes (128 bits) which then operated on with the IV and Key.
|
||||
func decrypt(key []byte, iv []byte, encrypted []byte) ([]byte, error) {
|
||||
if len(iv) != aes.BlockSize {
|
||||
return nil, errInvalidIVSize
|
||||
}
|
||||
if len(encrypted)%aes.BlockSize != 0 {
|
||||
return nil, errInvalidDataSize
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blockSize := block.BlockSize()
|
||||
if len(encrypted)%blockSize != 0 {
|
||||
return nil, fmt.Errorf("encrypted payload length %d is not a multiple of AES block size %d", len(encrypted), blockSize)
|
||||
}
|
||||
|
||||
if len(iv) != blockSize {
|
||||
return nil, fmt.Errorf("size of the IV %d is not the same as block size %d", len(iv), blockSize)
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
decrypted := make([]byte, len(encrypted))
|
||||
mode.CryptBlocks(decrypted, encrypted)
|
||||
return decrypted, err
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
+354
@@ -0,0 +1,354 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package re
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
validKey := make([]byte, aes.BlockSize)
|
||||
validIV := make([]byte, aes.BlockSize)
|
||||
validData := make([]byte, aes.BlockSize*2) // 2 blocks
|
||||
|
||||
_, err := rand.Read(validKey)
|
||||
require.Nil(t, err, "Failed to generate valid key")
|
||||
_, err = rand.Read(validIV)
|
||||
require.Nil(t, err, "Failed to generate valid IV")
|
||||
_, err = rand.Read(validData)
|
||||
require.Nil(t, err, "Failed to generate valid data")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
key []byte
|
||||
iv []byte
|
||||
data []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "valid encryption - single block",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
data: make([]byte, 16),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "valid encryption - multiple blocks",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
data: validData,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "valid encryption - empty data",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
data: make([]byte, 0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - too short",
|
||||
key: validKey,
|
||||
iv: make([]byte, 8),
|
||||
data: validData,
|
||||
err: errors.New("size of the IV 8 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - too long",
|
||||
key: validKey,
|
||||
iv: make([]byte, 32),
|
||||
data: validData,
|
||||
err: errors.New("size of the IV 32 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - nil",
|
||||
key: validKey,
|
||||
iv: nil,
|
||||
data: validData,
|
||||
err: errors.New("size of the IV 0 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid data size - not multiple of block size",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
data: make([]byte, 15),
|
||||
err: errors.New("payload length 15 is not a multiple of AES block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid data size - odd length",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
data: make([]byte, 17),
|
||||
err: errors.New("payload length 17 is not a multiple of AES block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid key size - too short",
|
||||
key: make([]byte, 8),
|
||||
iv: validIV,
|
||||
data: validData,
|
||||
err: aes.KeySizeError(8),
|
||||
},
|
||||
{
|
||||
name: "invalid key size - nil",
|
||||
key: nil,
|
||||
iv: validIV,
|
||||
data: validData,
|
||||
err: aes.KeySizeError(0),
|
||||
},
|
||||
{
|
||||
name: "AES-192 key",
|
||||
key: make([]byte, 24),
|
||||
iv: validIV,
|
||||
data: validData,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "AES-256 key",
|
||||
key: make([]byte, 32),
|
||||
iv: validIV,
|
||||
data: validData,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := encrypt(tt.key, tt.iv, tt.data)
|
||||
if tt.err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.err, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, len(tt.data), len(result))
|
||||
|
||||
// Ensure encrypted data is different from original (unless data is all zeros)
|
||||
if len(tt.data) > 0 && !bytes.Equal(tt.data, make([]byte, len(tt.data))) {
|
||||
assert.NotEqual(t, tt.data, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
validKey := make([]byte, aes.BlockSize)
|
||||
validIV := make([]byte, aes.BlockSize)
|
||||
validData := make([]byte, aes.BlockSize*2) // 2 blocks
|
||||
|
||||
_, err := rand.Read(validKey)
|
||||
require.Nil(t, err, "Failed to generate valid key")
|
||||
_, err = rand.Read(validIV)
|
||||
require.Nil(t, err, "Failed to generate valid IV")
|
||||
_, err = rand.Read(validData)
|
||||
require.Nil(t, err, "Failed to generate valid data")
|
||||
|
||||
validEncrypted, _ := encrypt(validKey, validIV, validData)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
key []byte
|
||||
iv []byte
|
||||
encrypted []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "valid decryption - single block",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted[:16],
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "valid decryption - multiple blocks",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "valid decryption - empty data",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
encrypted: make([]byte, 0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - too short",
|
||||
key: validKey,
|
||||
iv: make([]byte, 8),
|
||||
encrypted: validEncrypted,
|
||||
err: errors.New("size of the IV 8 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - too long",
|
||||
key: validKey,
|
||||
iv: make([]byte, 32),
|
||||
encrypted: validEncrypted,
|
||||
err: errors.New("size of the IV 32 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid IV size - nil",
|
||||
key: validKey,
|
||||
iv: nil,
|
||||
encrypted: validEncrypted,
|
||||
err: errors.New("size of the IV 0 is not the same as block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid encrypted data size - not multiple of block size",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
encrypted: make([]byte, 15),
|
||||
err: errors.New("encrypted payload length 15 is not a multiple of AES block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid encrypted data size - odd length",
|
||||
key: validKey,
|
||||
iv: validIV,
|
||||
encrypted: make([]byte, 17),
|
||||
err: errors.New("encrypted payload length 17 is not a multiple of AES block size 16"),
|
||||
},
|
||||
{
|
||||
name: "invalid key size - too short",
|
||||
key: make([]byte, 8),
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted,
|
||||
err: aes.KeySizeError(8),
|
||||
},
|
||||
{
|
||||
name: "invalid key size - nil",
|
||||
key: nil,
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted,
|
||||
err: aes.KeySizeError(0),
|
||||
},
|
||||
{
|
||||
name: "AES-192 key",
|
||||
key: make([]byte, 24),
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "AES-256 key",
|
||||
key: make([]byte, 32),
|
||||
iv: validIV,
|
||||
encrypted: validEncrypted,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := decrypt(tt.key, tt.iv, tt.encrypted)
|
||||
if tt.err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.err, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, len(tt.encrypted), len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptRoundTrip(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
keySize int
|
||||
dataSize int
|
||||
}{
|
||||
{
|
||||
name: "AES-128 single block",
|
||||
keySize: 16,
|
||||
dataSize: 16,
|
||||
},
|
||||
{
|
||||
name: "AES-128 multiple blocks",
|
||||
keySize: 16,
|
||||
dataSize: 64,
|
||||
},
|
||||
{
|
||||
name: "AES-192 single block",
|
||||
keySize: 24,
|
||||
dataSize: 16,
|
||||
},
|
||||
{
|
||||
name: "AES-192 multiple blocks",
|
||||
keySize: 24,
|
||||
dataSize: 48,
|
||||
},
|
||||
{
|
||||
name: "AES-256 single block",
|
||||
keySize: 32,
|
||||
dataSize: 16,
|
||||
},
|
||||
{
|
||||
name: "AES-256 multiple blocks",
|
||||
keySize: 32,
|
||||
dataSize: 80,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
keySize: 16,
|
||||
dataSize: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := make([]byte, tt.keySize)
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
originalData := make([]byte, tt.dataSize)
|
||||
|
||||
_, err := rand.Read(key)
|
||||
require.Nil(t, err, "Failed to generate valid key")
|
||||
_, err = rand.Read(iv)
|
||||
require.Nil(t, err, "Failed to generate valid IV")
|
||||
|
||||
if tt.dataSize > 0 {
|
||||
_, err = rand.Read(originalData)
|
||||
require.Nil(t, err, "Failed to generate valid data")
|
||||
}
|
||||
|
||||
encrypted, err := encrypt(key, iv, originalData)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, encrypted)
|
||||
|
||||
decrypted, err := decrypt(key, iv, encrypted)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, decrypted)
|
||||
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptWithSample(t *testing.T) {
|
||||
iv := "0907780613000704d2d2d2d2d2d2d2d2"
|
||||
ivBytes, err := hex.DecodeString(iv)
|
||||
assert.NoError(t, err, "Failed to decode IV")
|
||||
payload := "Ba56dc989e08a76f855ae12ae8B00ef13fae6ad436eBe8e03e97f17B5751c241"
|
||||
payloadBytes, err := hex.DecodeString(payload)
|
||||
assert.NoError(t, err, "Failed to decode payload")
|
||||
key := "CB6ABFAA8D2247B59127D3B839CF34B4"
|
||||
keyBytes, err := hex.DecodeString(key)
|
||||
assert.NoError(t, err, "Failed to decode key")
|
||||
expected := "2f2f0c0613760100046d27350f380c13555134022f2f2f2f2f2f2f2f2f2f2f2f"
|
||||
|
||||
decrypted, err := decrypt(keyBytes, ivBytes, payloadBytes)
|
||||
assert.NoError(t, err, "Failed to decrypt")
|
||||
assert.NotNil(t, decrypted, "Decrypted payload is nil")
|
||||
assert.Equal(t, expected, hex.EncodeToString(decrypted), "Decrypted payload does not match expected")
|
||||
}
|
||||
+24
-13
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/gob"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/magistrala/alarms"
|
||||
"github.com/absmach/senml"
|
||||
@@ -19,30 +20,39 @@ import (
|
||||
func luaEncrypt(l *lua.LState) int {
|
||||
key, iv, data, err := decodeParams(l)
|
||||
if err != nil {
|
||||
return 1
|
||||
l.Push(lua.LNil)
|
||||
l.Push(lua.LString(fmt.Sprintf("failed to decode params: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
enc, err := encrypt(key, iv, data)
|
||||
if err != nil {
|
||||
l.RaiseError("Falied to encrypt: %v", err)
|
||||
return 0
|
||||
l.Push(lua.LNil)
|
||||
l.Push(lua.LString(fmt.Sprintf("failed to encrypt: %v", err)))
|
||||
return 2
|
||||
}
|
||||
l.Push(lua.LString(hex.EncodeToString(enc)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaDecrypt(l *lua.LState) int {
|
||||
key, iv, data, err := decodeParams(l)
|
||||
if err != nil {
|
||||
return 1
|
||||
l.Push(lua.LNil)
|
||||
l.Push(lua.LString(fmt.Sprintf("failed to decode params: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
dec, err := decrypt(key, iv, data)
|
||||
if err != nil {
|
||||
l.RaiseError("Falied to decrypt: %v", err)
|
||||
return 0
|
||||
l.Push(lua.LNil)
|
||||
l.Push(lua.LString(fmt.Sprintf("failed to decrypt: %v", err)))
|
||||
return 2
|
||||
}
|
||||
|
||||
l.Push(lua.LString(hex.EncodeToString(dec)))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -53,22 +63,20 @@ func decodeParams(l *lua.LState) (key, iv, data []byte, err error) {
|
||||
|
||||
key, err = hex.DecodeString(keyStr)
|
||||
if err != nil {
|
||||
l.RaiseError("Failed to decode key: %v", err)
|
||||
return
|
||||
return nil, nil, nil, fmt.Errorf("failed to decode key: %v", err)
|
||||
}
|
||||
|
||||
iv, err = hex.DecodeString(ivStr)
|
||||
if err != nil {
|
||||
l.RaiseError("Failed to decode IV: %v", err)
|
||||
return
|
||||
return nil, nil, nil, fmt.Errorf("failed to decode IV: %v", err)
|
||||
}
|
||||
|
||||
data, err = hex.DecodeString(dataStr)
|
||||
if err != nil {
|
||||
l.RaiseError("Failed to decode data: %v", err)
|
||||
return
|
||||
return nil, nil, nil, fmt.Errorf("failed to decode data: %v", err)
|
||||
}
|
||||
return
|
||||
|
||||
return key, iv, data, nil
|
||||
}
|
||||
|
||||
func (re *re) sendEmail(l *lua.LState) int {
|
||||
@@ -86,6 +94,7 @@ func (re *re) sendEmail(l *lua.LState) int {
|
||||
if err := re.email.SendEmailNotification(recipients, "", subject, "", "", content, "", make(map[string][]byte)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -169,6 +178,7 @@ func (re *re) saveSenml(ctx context.Context, val interface{}, msg *messaging.Mes
|
||||
if err := re.writersPub.Publish(ctx, msg.Channel, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,5 +200,6 @@ func (re *re) publishChannel(ctx context.Context, val interface{}, channel, subt
|
||||
if err := re.rePubSub.Publish(ctx, channel, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,429 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package re
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func TestDecodeParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyStr string
|
||||
ivStr string
|
||||
dataStr string
|
||||
expectedKey []byte
|
||||
expectedIV []byte
|
||||
expectedData []byte
|
||||
err bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid hex strings",
|
||||
keyStr: "0123456789abcdef0123456789abcdef", // 32 chars = 16 bytes
|
||||
ivStr: "fedcba9876543210fedcba9876543210", // 32 chars = 16 bytes
|
||||
dataStr: "deadbeefcafebabe0000111122223333", // 32 chars = 16 bytes
|
||||
expectedKey: []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
|
||||
expectedIV: []byte{0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10},
|
||||
expectedData: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, 0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "",
|
||||
expectedKey: []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
|
||||
expectedIV: []byte{0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10},
|
||||
expectedData: []byte{},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "invalid key hex",
|
||||
keyStr: "invalid_hex",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
err: true,
|
||||
errMsg: "failed to decode key",
|
||||
},
|
||||
{
|
||||
name: "invalid IV hex",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "invalid_hex",
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
err: true,
|
||||
errMsg: "failed to decode IV",
|
||||
},
|
||||
{
|
||||
name: "invalid data hex",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "invalid_hex",
|
||||
err: true,
|
||||
errMsg: "failed to decode data",
|
||||
},
|
||||
{
|
||||
name: "odd length key",
|
||||
keyStr: "0123456789abcdef0123456789abcde", // 31 chars (odd)
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
err: true,
|
||||
errMsg: "failed to decode key",
|
||||
},
|
||||
{
|
||||
name: "odd length IV",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba987654321", // 31 chars (odd)
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
err: true,
|
||||
errMsg: "failed to decode IV",
|
||||
},
|
||||
{
|
||||
name: "odd length data",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe000011112222333", // 31 chars (odd)
|
||||
err: true,
|
||||
errMsg: "failed to decode data",
|
||||
},
|
||||
{
|
||||
name: "uppercase hex",
|
||||
keyStr: "0123456789ABCDEF0123456789ABCDEF",
|
||||
ivStr: "FEDCBA9876543210FEDCBA9876543210",
|
||||
dataStr: "DEADBEEFCAFEBABE0000111122223333",
|
||||
expectedKey: []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
|
||||
expectedIV: []byte{0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10},
|
||||
expectedData: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, 0x00, 0x00, 0x11, 0x11, 0x22, 0x22, 0x33, 0x33},
|
||||
err: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
L.Push(lua.LString(tt.keyStr))
|
||||
L.Push(lua.LString(tt.ivStr))
|
||||
L.Push(lua.LString(tt.dataStr))
|
||||
|
||||
key, iv, data, err := decodeParams(L)
|
||||
|
||||
if tt.err {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedKey, key)
|
||||
assert.Equal(t, tt.expectedIV, iv)
|
||||
assert.Equal(t, tt.expectedData, data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLuaEncrypt(t *testing.T) {
|
||||
validKey := "0123456789abcdef0123456789abcdef" // 16 bytes
|
||||
validIV := "fedcba9876543210fedcba9876543210" // 16 bytes
|
||||
validData := "deadbeefcafebabe0000111122223333" // 16 bytes
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyStr string
|
||||
ivStr string
|
||||
dataStr string
|
||||
expectReturn int
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful encryption",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: validData,
|
||||
expectReturn: 1,
|
||||
},
|
||||
{
|
||||
name: "successful encryption with empty data",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "",
|
||||
expectReturn: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid key hex",
|
||||
keyStr: "invalid_hex",
|
||||
ivStr: validIV,
|
||||
dataStr: validData,
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid IV hex",
|
||||
keyStr: validKey,
|
||||
ivStr: "invalid_hex",
|
||||
dataStr: validData,
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid data hex",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "invalid_hex",
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid key size",
|
||||
keyStr: "0123456789abcdef", // 8 bytes, too short
|
||||
ivStr: validIV,
|
||||
dataStr: validData,
|
||||
errMsg: "failed to encrypt",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid IV size",
|
||||
keyStr: validKey,
|
||||
ivStr: "0123456789abcdef", // 8 bytes, too short
|
||||
dataStr: validData,
|
||||
errMsg: "failed to encrypt",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid data size",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "deadbeefcafebabe000011112222333", // 15 bytes, not multiple of 16
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
L.Push(lua.LString(tt.keyStr))
|
||||
L.Push(lua.LString(tt.ivStr))
|
||||
L.Push(lua.LString(tt.dataStr))
|
||||
|
||||
result := luaEncrypt(L)
|
||||
assert.Equal(t, tt.expectReturn, result)
|
||||
|
||||
if tt.expectReturn == 1 {
|
||||
encryptedHex := L.ToString(-1)
|
||||
_, err := hex.DecodeString(encryptedHex)
|
||||
assert.NoError(t, err, "Pushed value should be valid hex")
|
||||
} else {
|
||||
err := L.ToString(-1)
|
||||
assert.Contains(t, err, tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLuaDecrypt(t *testing.T) {
|
||||
validKey := "0123456789abcdef0123456789abcdef" // 16 bytes
|
||||
validIV := "fedcba9876543210fedcba9876543210" // 16 bytes
|
||||
|
||||
// Create valid encrypted data by first encrypting some data
|
||||
keyBytes, _ := hex.DecodeString(validKey)
|
||||
ivBytes, _ := hex.DecodeString(validIV)
|
||||
plainData := []byte("1234567890123456") // 16 bytes
|
||||
encryptedData, _ := encrypt(keyBytes, ivBytes, plainData)
|
||||
validEncryptedStr := hex.EncodeToString(encryptedData)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyStr string
|
||||
ivStr string
|
||||
dataStr string
|
||||
expectReturn int
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful decryption",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: validEncryptedStr,
|
||||
expectReturn: 1,
|
||||
},
|
||||
{
|
||||
name: "successful decryption with empty data",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "",
|
||||
expectReturn: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid key hex",
|
||||
keyStr: "invalid_hex",
|
||||
ivStr: validIV,
|
||||
dataStr: validEncryptedStr,
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid IV hex",
|
||||
keyStr: validKey,
|
||||
ivStr: "invalid_hex",
|
||||
dataStr: validEncryptedStr,
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid encrypted data hex",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "invalid_hex",
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid key size",
|
||||
keyStr: "0123456789abcdef", // 8 bytes, too short
|
||||
ivStr: validIV,
|
||||
dataStr: validEncryptedStr,
|
||||
errMsg: "failed to decrypt",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid IV size",
|
||||
keyStr: validKey,
|
||||
ivStr: "0123456789abcdef", // 8 bytes, too short
|
||||
dataStr: validEncryptedStr,
|
||||
errMsg: "failed to decrypt",
|
||||
expectReturn: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid encrypted data size",
|
||||
keyStr: validKey,
|
||||
ivStr: validIV,
|
||||
dataStr: "deadbeefcafebabe000011112222333", // 15 bytes, not multiple of 16
|
||||
errMsg: "failed to decode params",
|
||||
expectReturn: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
L.Push(lua.LString(tt.keyStr))
|
||||
L.Push(lua.LString(tt.ivStr))
|
||||
L.Push(lua.LString(tt.dataStr))
|
||||
|
||||
result := luaDecrypt(L)
|
||||
assert.Equal(t, tt.expectReturn, result)
|
||||
|
||||
if tt.expectReturn == 1 {
|
||||
decryptedHex := L.ToString(-1)
|
||||
_, err := hex.DecodeString(decryptedHex)
|
||||
assert.NoError(t, err, "Pushed value should be valid hex")
|
||||
} else {
|
||||
err := L.ToString(-1)
|
||||
assert.Contains(t, err, tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLuaDecryptWithSample(t *testing.T) {
|
||||
iv := "0907780613000704d2d2d2d2d2d2d2d2"
|
||||
payload := "Ba56dc989e08a76f855ae12ae8B00ef13fae6ad436eBe8e03e97f17B5751c241"
|
||||
key := "CB6ABFAA8D2247B59127D3B839CF34B4"
|
||||
expected := "2f2f0c0613760100046d27350f380c13555134022f2f2f2f2f2f2f2f2f2f2f2f"
|
||||
|
||||
L := lua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
L.Push(lua.LString(key))
|
||||
L.Push(lua.LString(iv))
|
||||
L.Push(lua.LString(payload))
|
||||
|
||||
result := luaDecrypt(L)
|
||||
if result != 1 {
|
||||
t.Errorf("luaDecrypt() expected 1 return value, got %d", result)
|
||||
return
|
||||
}
|
||||
|
||||
decrypted, err := hex.DecodeString(L.ToString(-1))
|
||||
require.Nil(t, err, "Failed to decode decrypted payload")
|
||||
assert.Equal(t, expected, hex.EncodeToString(decrypted), "Decrypted payload does not match expected")
|
||||
}
|
||||
|
||||
func TestLuaEncryptDecryptRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyStr string
|
||||
ivStr string
|
||||
dataStr string
|
||||
}{
|
||||
{
|
||||
name: "single block round trip",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
},
|
||||
{
|
||||
name: "multiple block round trip",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe0000111122223333cafebabe0123456789abcdef01234567",
|
||||
},
|
||||
{
|
||||
name: "empty data round trip",
|
||||
keyStr: "0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "",
|
||||
},
|
||||
{
|
||||
name: "AES-256 round trip",
|
||||
keyStr: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
ivStr: "fedcba9876543210fedcba9876543210",
|
||||
dataStr: "deadbeefcafebabe0000111122223333",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test encryption
|
||||
L1 := lua.NewState()
|
||||
defer L1.Close()
|
||||
|
||||
L1.Push(lua.LString(tt.keyStr))
|
||||
L1.Push(lua.LString(tt.ivStr))
|
||||
L1.Push(lua.LString(tt.dataStr))
|
||||
|
||||
result := luaEncrypt(L1)
|
||||
require.Equal(t, 1, result)
|
||||
|
||||
// Get encrypted result
|
||||
encryptedHex := L1.ToString(-1)
|
||||
|
||||
// Test decryption
|
||||
L2 := lua.NewState()
|
||||
defer L2.Close()
|
||||
|
||||
L2.Push(lua.LString(tt.keyStr))
|
||||
L2.Push(lua.LString(tt.ivStr))
|
||||
L2.Push(lua.LString(encryptedHex))
|
||||
|
||||
result = luaDecrypt(L2)
|
||||
require.Equal(t, 1, result)
|
||||
|
||||
// Verify round trip
|
||||
decryptedHex := L2.ToString(-1)
|
||||
assert.Equal(t, strings.ToLower(tt.dataStr), strings.ToLower(decryptedHex))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user