add cli tests (#274)

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-08 17:28:17 +03:00
committed by GitHub
parent 034547d667
commit 7ef25674c4
8 changed files with 596 additions and 12 deletions
+10 -6
View File
@@ -154,26 +154,30 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
Example: "get <report_data>",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
log.Println("Getting attestation")
cmd.Println("Getting attestation")
reportData, err := hex.DecodeString(args[0])
if err != nil {
log.Fatalf("attestation validation and verification failed with error: %s", err)
cmd.Printf("attestation validation and verification failed with error: %s", err)
return
}
if len(reportData) != agent.ReportDataSize {
log.Fatalf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize)
cmd.Printf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize)
return
}
result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData))
if err != nil {
log.Fatalf("Error retrieving attestation: %v", err)
cmd.Printf("Error retrieving attestation: %v", err)
return
}
if err = os.WriteFile(attestationFilePath, result, 0o644); err != nil {
log.Fatalf("Error saving attestation result: %v", err)
cmd.Printf("Error saving attestation result: %v", err)
return
}
log.Println("Attestation result retrieved and saved successfully!")
cmd.Println("Attestation result retrieved and saved successfully!")
},
}
}
+190
View File
@@ -0,0 +1,190 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"encoding/hex"
"fmt"
"os"
"testing"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
)
func TestNewAttestationCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewAttestationCmd()
assert.Equal(t, "attestation [command]", cmd.Use)
assert.Equal(t, "Get and validate attestations", cmd.Short)
}
func TestNewGetAttestationCmd(t *testing.T) {
mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
cmd := cli.NewGetAttestationCmd()
var buf bytes.Buffer
cmd.SetOutput(&buf)
assert.Equal(t, "get", cmd.Use)
assert.Equal(t, "Retrieve attestation information from agent. Report data expected in hex enoded string of length 64 bytes.", cmd.Short)
reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize)
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData)).Return([]byte("mock attestation"), nil)
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
err := cmd.Execute()
assert.NoError(t, err)
assert.Contains(t, buf.String(), "Attestation result retrieved and saved successfully!")
os.Remove(attestationFilePath)
}
func TestNewValidateAttestationValidationCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewValidateAttestationValidationCmd()
assert.Equal(t, "validate", cmd.Use)
assert.Equal(t, "Validate and verify attestation information. The report is provided as a file path.", cmd.Short)
assert.Equal(t, fmt.Sprint(defaultMinimumTcb), cmd.Flag("minimum_tcb").Value.String())
assert.Equal(t, fmt.Sprint(defaultMinimumLaunchTcb), cmd.Flag("minimum_lauch_tcb").Value.String())
assert.Equal(t, fmt.Sprint(defaultGuestPolicy), cmd.Flag("guest_policy").Value.String())
assert.Equal(t, fmt.Sprint(defaultMinimumGuestSvn), cmd.Flag("minimum_guest_svn").Value.String())
assert.Equal(t, fmt.Sprint(defaultMinimumBuild), cmd.Flag("minimum_build").Value.String())
assert.Equal(t, defaultCheckCrl, cmd.Flag("check_crl").Value.String() == "true")
assert.Equal(t, fmt.Sprint(defaultTimeout), cmd.Flag("timeout").Value.String())
assert.Equal(t, fmt.Sprint(defaultMaxRetryDelay), cmd.Flag("max_retry_delay").Value.String())
}
func TestParseConfig(t *testing.T) {
cfgString = ""
err := parseConfig()
assert.NoError(t, err)
assert.NotNil(t, cfg.RootOfTrust)
assert.NotNil(t, cfg.Policy)
cfgString = `{"rootOfTrust":{"product":"test_product"},"policy":{"minimumGuestSvn":1}}`
err = parseConfig()
assert.NoError(t, err)
assert.Equal(t, "test_product", cfg.RootOfTrust.Product)
assert.Equal(t, uint32(1), cfg.Policy.MinimumGuestSvn)
cfgString = `{"invalid_json"`
err = parseConfig()
assert.Error(t, err)
}
func TestParseHashes(t *testing.T) {
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
cfg = check.Config{}
if cfg.Policy == nil {
cfg.Policy = &check.Policy{}
}
err := parseHashes()
assert.NoError(t, err)
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, 1)
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, 1)
trustedAuthorHashes = []string{"invalid_hash"}
err = parseHashes()
assert.Error(t, err)
}
func TestParseFiles(t *testing.T) {
attestationFile = "test_attestation.bin"
authorKeyFile := "test_author_key.pem"
idKeyFile := "test_id_key.pem"
err := os.WriteFile(attestationFile, []byte("test attestation"), 0o644)
assert.NoError(t, err)
err = os.WriteFile(authorKeyFile, []byte("test author key"), 0o644)
assert.NoError(t, err)
err = os.WriteFile(idKeyFile, []byte("test id key"), 0o644)
assert.NoError(t, err)
trustedAuthorKeys = []string{authorKeyFile}
trustedIdKeys = []string{idKeyFile}
err = parseFiles()
assert.NoError(t, err)
assert.Equal(t, []byte("test attestation"), attestation)
assert.Len(t, cfg.Policy.TrustedAuthorKeys, 1)
assert.Len(t, cfg.Policy.TrustedIdKeys, 1)
os.Remove(attestationFile)
os.Remove(authorKeyFile)
os.Remove(idKeyFile)
attestationFile = "non_existent_file.bin"
err = parseFiles()
assert.Error(t, err)
}
func TestParseUints(t *testing.T) {
stepping = "10"
platformInfo = "0xFF"
cfg = check.Config{}
if cfg.Policy == nil {
cfg.Policy = &check.Policy{
Product: &sevsnp.SevProduct{},
}
}
err := parseUints()
assert.NoError(t, err)
assert.Equal(t, uint32(10), cfg.Policy.Product.MachineStepping.Value)
assert.Equal(t, uint64(255), cfg.Policy.PlatformInfo.Value)
stepping = "invalid"
err = parseUints()
assert.Error(t, err)
stepping = "10"
platformInfo = "invalid"
err = parseUints()
assert.Error(t, err)
}
func TestValidateInput(t *testing.T) {
cfg = check.Config{}
if cfg.Policy == nil {
cfg.Policy = &check.Policy{}
}
if cfg.RootOfTrust == nil {
cfg.RootOfTrust = &check.RootOfTrust{}
}
cfg.Policy.ReportData = make([]byte, 64)
cfg.Policy.HostData = make([]byte, 32)
cfg.Policy.FamilyId = make([]byte, 16)
cfg.Policy.ImageId = make([]byte, 16)
cfg.Policy.ReportId = make([]byte, 32)
cfg.Policy.ReportIdMa = make([]byte, 32)
cfg.Policy.Measurement = make([]byte, 48)
cfg.Policy.ChipId = make([]byte, 64)
err := validateInput()
assert.NoError(t, err)
cfg.Policy.ReportData = make([]byte, 32)
err = validateInput()
assert.Error(t, err)
}
func TestGetBase(t *testing.T) {
assert.Equal(t, 16, getBase("0xFF"))
assert.Equal(t, 8, getBase("0o77"))
assert.Equal(t, 2, getBase("0b1010"))
assert.Equal(t, 10, getBase("123"))
}
+136
View File
@@ -0,0 +1,136 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"encoding/base64"
"encoding/json"
"os"
"testing"
"github.com/google/go-sev-guest/proto/check"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestChangeAttestationConfiguration(t *testing.T) {
tmpfile, err := os.CreateTemp("", "backend_info.json")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
initialConfig := AttestationConfiguration{
SNPPolicy: &check.Policy{
Measurement: make([]byte, measurementLength),
HostData: make([]byte, hostDataLength),
},
}
initialJSON, err := json.Marshal(initialConfig)
require.NoError(t, err)
err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644)
require.NoError(t, err)
tests := []struct {
name string
base64Data string
expectedLength int
field fieldType
expectError bool
errorType error
}{
{
name: "Valid Measurement",
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
expectedLength: measurementLength,
field: measurementField,
expectError: false,
},
{
name: "Valid Host Data",
base64Data: base64.StdEncoding.EncodeToString(make([]byte, hostDataLength)),
expectedLength: hostDataLength,
field: hostDataField,
expectError: false,
},
{
name: "Invalid Base64",
base64Data: "Invalid Base64",
expectedLength: measurementLength,
field: measurementField,
expectError: true,
errorType: errDecode,
},
{
name: "Invalid Data Length",
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength-1)),
expectedLength: measurementLength,
field: measurementField,
expectError: true,
errorType: errDataLength,
},
{
name: "Invalid Field Type",
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
expectedLength: measurementLength,
field: fieldType(999),
expectError: true,
errorType: errBackendField,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := changeAttestationConfiguration(tmpfile.Name(), tt.base64Data, tt.expectedLength, tt.field)
if tt.expectError {
assert.Error(t, err)
assert.ErrorIs(t, err, tt.errorType)
} else {
assert.NoError(t, err)
content, err := os.ReadFile(tmpfile.Name())
require.NoError(t, err)
var config AttestationConfiguration
err = json.Unmarshal(content, &config)
require.NoError(t, err)
decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data)
if tt.field == measurementField {
assert.Equal(t, decodedData, config.SNPPolicy.Measurement)
} else if tt.field == hostDataField {
assert.Equal(t, decodedData, config.SNPPolicy.HostData)
}
}
})
}
}
func TestNewBackendCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewBackendCmd()
assert.Equal(t, "backend [command]", cmd.Use)
assert.Equal(t, "Change backend information", cmd.Short)
assert.NotNil(t, cmd.Run)
}
func TestNewAddMeasurementCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewAddMeasurementCmd()
assert.Equal(t, "measurement", cmd.Use)
assert.Equal(t, "Add measurement to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short)
assert.Equal(t, "measurement <measurement> <backend_info.json>", cmd.Example)
assert.NotNil(t, cmd.Run)
}
func TestNewAddHostDataCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewAddHostDataCmd()
assert.Equal(t, "hostdata", cmd.Use)
assert.Equal(t, "Add host data to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short)
assert.Equal(t, "hostdata <host-data> <backend_info.json>", cmd.Example)
assert.NotNil(t, cmd.Run)
}
+59
View File
@@ -0,0 +1,59 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"os"
"path"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewCABundleCmd(t *testing.T) {
cli := &CLI{}
tempDir, err := os.MkdirTemp("", "ca-bundle-test")
assert.NoError(t, err)
defer os.RemoveAll(tempDir)
manifestContent := []byte(`{"root_of_trust": {"product": "Milan"}}`)
manifestPath := path.Join(tempDir, "manifest.json")
err = os.WriteFile(manifestPath, manifestContent, 0o644)
assert.NoError(t, err)
cmd := cli.NewCABundleCmd(tempDir)
cmd.SetArgs([]string{manifestPath})
output := &bytes.Buffer{}
cmd.SetOutput(output)
err = cmd.Execute()
assert.NoError(t, err)
expectedFilePath := path.Join(tempDir, "Milan", caBundleName)
_, err = os.Stat(expectedFilePath)
assert.NoError(t, err)
content, err := os.ReadFile(expectedFilePath)
assert.NoError(t, err)
assert.NotNil(t, content)
}
func TestSaveToFile(t *testing.T) {
tempDir, err := os.MkdirTemp("", "save-to-file-test")
assert.NoError(t, err)
defer os.RemoveAll(tempDir)
filePath := path.Join(tempDir, "test-file.txt")
content := []byte("test content")
err = saveToFile(filePath, content)
assert.NoError(t, err)
savedContent, err := os.ReadFile(filePath)
assert.NoError(t, err)
assert.Equal(t, content, savedContent)
_, err = os.Stat(filePath)
assert.NoError(t, err)
}
+3 -4
View File
@@ -3,8 +3,6 @@
package cli
import (
"log"
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/internal"
)
@@ -20,10 +18,11 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command {
hash, err := internal.ChecksumHex(path)
if err != nil {
log.Fatalf("Error computing hash: %v", err)
cmd.Printf("Error computing hash: %v", err)
return
}
log.Println("Hash of file:", hash)
cmd.Println("Hash of file:", hash)
},
}
}
+102
View File
@@ -0,0 +1,102 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"os"
"strings"
"testing"
"github.com/ultravioletrs/cocos/internal"
)
func TestNewFileHashCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewFileHashCmd()
if cmd.Use != "checksum" {
t.Errorf("Expected Use to be 'checksum', got %s", cmd.Use)
}
if cmd.Short != "Compute the sha3-256 hash of a file" {
t.Errorf("Expected Short to be 'Compute the sha3-256 hash of a file', got %s", cmd.Short)
}
if cmd.Example != "checksum <file>" {
t.Errorf("Expected Example to be 'checksum <file>', got %s", cmd.Example)
}
}
func TestNewFileHashCmdRun(t *testing.T) {
cli := &CLI{}
cmd := cli.NewFileHashCmd()
content := []byte("test content")
tmpfile, err := os.CreateTemp("", "example")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(content); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}
var output bytes.Buffer
cmd.SetOut(&output)
cmd.SetErr(&output)
cmd.SetArgs([]string{tmpfile.Name()})
err = cmd.Execute()
if err != nil {
t.Fatalf("Error executing command: %v", err)
}
expectedHash, err := internal.ChecksumHex(tmpfile.Name())
if err != nil {
t.Fatalf("Error computing expected hash: %v", err)
}
if !strings.Contains(output.String(), expectedHash) {
t.Errorf("Expected output to contain hash %s, got %s", expectedHash, output.String())
}
}
func TestNewFileHashCmdInvalidArgs(t *testing.T) {
cli := &CLI{}
cmd := cli.NewFileHashCmd()
err := cmd.Execute()
if err == nil {
t.Error("Expected error when executing without arguments, got nil")
}
cmd.SetArgs([]string{"file1", "file2"})
err = cmd.Execute()
if err == nil {
t.Error("Expected error when executing with too many arguments, got nil")
}
}
func TestNewFileHashCmdNonExistentFile(t *testing.T) {
cli := &CLI{}
cmd := cli.NewFileHashCmd()
var output bytes.Buffer
cmd.SetOut(&output)
cmd.SetErr(&output)
cmd.SetArgs([]string{"non_existent_file.txt"})
err := cmd.Execute()
if err != nil {
t.Fatalf("Error executing command: %v", err)
}
if !strings.Contains(output.String(), "Error computing hash") {
t.Errorf("Expected output to contain 'Error computing hash', got %s", output.String())
}
}
+4 -2
View File
@@ -25,6 +25,8 @@ const (
publicKeyType = "PUBLIC KEY"
publicKeyFile = "public.pem"
privateKeyFile = "private.pem"
ECDSA = "ecdsa"
ED25519 = "ed25519"
)
var KeyType string
@@ -39,7 +41,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
Args: cobra.ExactArgs(0),
Run: func(cmd *cobra.Command, args []string) {
switch KeyType {
case "ecdsa":
case ECDSA:
privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
log.Fatalf("Error generating keys: %v", err)
@@ -52,7 +54,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType)
case "ed25519":
case ED25519:
pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
log.Fatalf("Error generating keys: %v", err)
+92
View File
@@ -0,0 +1,92 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"testing"
)
func TestNewKeysCmd(t *testing.T) {
cli := &CLI{}
cmd := cli.NewKeysCmd()
if cmd.Use != "keys" {
t.Errorf("Expected Use to be 'keys', got %s", cmd.Use)
}
if cmd.Short != "Generate a new public/private key pair" {
t.Errorf("Unexpected Short description: %s", cmd.Short)
}
}
func TestGenerateAndWriteKeys(t *testing.T) {
tests := []struct {
name string
keyType string
}{
{"RSA", "rsa"},
{"ECDSA", "ecdsa"},
{"ED25519", "ed25519"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
KeyType = tt.keyType
cmd := (&CLI{}).NewKeysCmd()
cmd.Run(cmd, []string{})
if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) {
t.Errorf("Private key file was not created")
}
if _, err := os.Stat(publicKeyFile); os.IsNotExist(err) {
t.Errorf("Public key file was not created")
}
privKeyData, err := os.ReadFile(privateKeyFile)
if err != nil {
t.Fatalf("Failed to read private key file: %v", err)
}
privPem, _ := pem.Decode(privKeyData)
if privPem == nil {
t.Fatalf("Failed to decode private key PEM")
}
var privKey interface{}
switch tt.keyType {
case "rsa":
privKey, err = x509.ParsePKCS1PrivateKey(privPem.Bytes)
case "ecdsa":
privKey, err = x509.ParseECPrivateKey(privPem.Bytes)
case "ed25519":
privKey, err = x509.ParsePKCS8PrivateKey(privPem.Bytes)
}
if err != nil {
t.Fatalf("Failed to parse private key: %v", err)
}
switch tt.keyType {
case "rsa":
if _, ok := privKey.(*rsa.PrivateKey); !ok {
t.Errorf("Expected RSA private key, got %T", privKey)
}
case "ecdsa":
if _, ok := privKey.(*ecdsa.PrivateKey); !ok {
t.Errorf("Expected ECDSA private key, got %T", privKey)
}
case "ed25519":
if _, ok := privKey.(ed25519.PrivateKey); !ok {
t.Errorf("Expected ED25519 private key, got %T", privKey)
}
}
os.Remove(privateKeyFile)
os.Remove(publicKeyFile)
})
}
}