COCOS-460 - Restore test coverage to 65% (#465)
CI / ci (push) Has been cancelled

* Implement IMAMeasurements method in agentSDK and add corresponding unit tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for NewIMAMeasurements command in CLI

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add error assertion for command execution in NewIMAMeasurements test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix nil pointer dereference in Close method and update NewCreateVMCmd logic for manager client initialization

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor file permission settings to use octal notation and improve cleanup handling in NewCreateVMCmd test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive unit tests for state machine functionality

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add mock implementation for Algorithm interface and corresponding test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor file permission settings to use octal notation in TestStopComputationIntegration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove redundant reset test cases from TestStateMachine_Reset

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix race condition in action call verification in TestStateMachine_HandleEvent

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance state machine with reset functionality and improve thread safety in event handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Improve error handling in state machine start function during tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove concurrent reset and send event test from state machine tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove error logging for Start function in transition tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add mock implementations for AgentService_IMAMeasurementsClient and Service Shutdown method; enhance progress tests for IMA measurements handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for FileStorage functionality including loading, saving, and concurrent access

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance tests by adding dataset and algorithm hashes in handleRunReqChunks; improve error handling in TestFileStorage_ErrorHandling cleanup

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance TestManagerClient_Process by adding new test cases for Agent state and Disconnect requests; update setupMocks to include grpcClient

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix graceful shutdown in gRPC server by adding nil checks for health and server instances

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance TestAttestation by adding mock expectations for VTpmAttestation and Attestation methods; update service call to include platform parameter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance gRPC Server by adding synchronization for start/stop methods; prevent multiple starts and ensure graceful shutdown

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for gRPC server methods including VM creation, removal, and info retrieval

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add tests for SEVSNP and TDX host capabilities; remove unused vsock code

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add a newline for better readability in vm_test.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add integration tests for gRPC client in cvm_test.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove unused vsock dependencies and add comprehensive unit tests for GCP attestation functions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Skip GCP tests if credentials are not set

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add tests for error handling in attestation configuration and GCP commands

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Improve error handling in Azure VM test response writing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Skip tests in GCP functions if credentials are not set

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive unit tests for Azure attestation provider and verifier

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for TPM functionality and improve error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for attestation functionality and improve error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add validation for teeNonce in TeeAttestation and implement comprehensive tests for provider methods

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor error messages in TDX attestation tests for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix error message in TeeAttestation test for valid nonce case

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add MeasurementProvider mock and update mockery configuration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add logging for product in parseUints and rename test functions for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor TestSevsnpverify to reset configuration and improve error logging

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-07-25 16:35:37 +03:00
committed by GitHub
parent 85a2b7a6c8
commit 4e8057f481
43 changed files with 9194 additions and 321 deletions
+245
View File
@@ -3,6 +3,7 @@
package cli
import (
"bytes"
"encoding/base64"
"encoding/json"
"os"
@@ -131,3 +132,247 @@ func TestNewAddHostDataCmd(t *testing.T) {
assert.Equal(t, "hostdata <host-data> <attestation_policy.json>", cmd.Example)
assert.NotNil(t, cmd.Run)
}
func TestChangeAttestationConfigurationFileErrors(t *testing.T) {
t.Run("File Not Found", func(t *testing.T) {
err := changeAttestationConfiguration("nonexistent.json", base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
assert.Error(t, err)
assert.Contains(t, err.Error(), "error while reading the attestation policy file")
})
t.Run("Invalid JSON Content", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "invalid.json")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("invalid json"), 0o644)
require.NoError(t, err)
err = changeAttestationConfiguration(tmpfile.Name(), base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal json")
})
}
func TestNewGCPAttestationPolicy(t *testing.T) {
cli := &CLI{}
cmd := cli.NewGCPAttestationPolicy()
assert.Equal(t, "gcp", cmd.Use)
assert.Equal(t, "Get attestation policy for GCP CVM", cmd.Short)
assert.Equal(t, "gcp <bin_vtmp_attestation_report_file> <vcpu_count>", cmd.Example)
assert.NotNil(t, cmd.Run)
t.Run("File Not Found", func(t *testing.T) {
cmd.SetArgs([]string{"nonexistent.bin", "4"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error reading attestation report file")
assert.Contains(t, output, "❌")
})
t.Run("Invalid vCPU Count", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "attestation.bin")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("dummy content"), 0o644)
require.NoError(t, err)
cmd.SetArgs([]string{tmpfile.Name(), "invalid"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error converting vCPU count to integer")
assert.Contains(t, output, "❌")
})
t.Run("Invalid Attestation Data", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "attestation.bin")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
require.NoError(t, err)
cmd.SetArgs([]string{tmpfile.Name(), "4"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error unmarshaling attestation report")
assert.Contains(t, output, "❌")
})
}
func TestNewDownloadGCPOvmfFile(t *testing.T) {
cli := &CLI{}
cmd := cli.NewDownloadGCPOvmfFile()
assert.Equal(t, "download", cmd.Use)
assert.Equal(t, "Download GCP OVMF file", cmd.Short)
assert.Equal(t, "download <bin_vtmp_attestation_report_file>", cmd.Example)
assert.NotNil(t, cmd.Run)
t.Run("File Not Found", func(t *testing.T) {
cmd.SetArgs([]string{"nonexistent.bin"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error reading attestation report file")
assert.Contains(t, output, "❌")
})
t.Run("Invalid Attestation Data", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "attestation.bin")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
require.NoError(t, err)
cmd.SetArgs([]string{tmpfile.Name()})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error unmarshaling attestation report")
assert.Contains(t, output, "❌")
})
}
func TestNewAzureAttestationPolicy(t *testing.T) {
cli := &CLI{}
cmd := cli.NewAzureAttestationPolicy()
assert.Equal(t, "azure", cmd.Use)
assert.Equal(t, "Get attestation policy for Azure CVM", cmd.Short)
assert.Equal(t, "azure <azure_maa_token_file> <product_name>", cmd.Example)
assert.NotNil(t, cmd.Run)
flag := cmd.Flags().Lookup("policy")
assert.NotNil(t, flag)
assert.Equal(t, "Policy of the guest CVM", flag.Usage)
t.Run("File Not Found", func(t *testing.T) {
cmd.SetArgs([]string{"nonexistent.token", "test-product"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error reading attestation report file")
assert.Contains(t, output, "❌")
})
t.Run("Valid Token File", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "token.maa")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
require.NoError(t, err)
defer os.Remove("attestation_policy.json")
cmd.SetArgs([]string{tmpfile.Name(), "test-product"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
assert.NoError(t, err)
})
t.Run("Custom Policy Flag", func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "token.maa")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
require.NoError(t, err)
cmd.SetArgs([]string{"--policy", "123456", tmpfile.Name(), "test-product"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
assert.NoError(t, err)
flag := cmd.Flags().Lookup("policy")
assert.NotNil(t, flag)
assert.Equal(t, "123456", flag.Value.String())
})
}
func TestCommandErrorHandling(t *testing.T) {
cli := &CLI{}
t.Run("Measurement Command Error", func(t *testing.T) {
cmd := cli.NewAddMeasurementCmd()
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error could not change measurement data")
assert.Contains(t, output, "❌")
})
t.Run("Host Data Command Error", func(t *testing.T) {
cmd := cli.NewAddHostDataCmd()
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "Error could not change host data")
assert.Contains(t, output, "❌")
})
}
+1
View File
@@ -331,6 +331,7 @@ func parseUints() error {
if err != nil {
return err
}
cfg.Policy.Product.MachineStepping = wrapperspb.UInt32(uint32(num))
} else {
num, err := strconv.ParseUint(stepping[2:], base, 8)
+870
View File
@@ -0,0 +1,870 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
tpmAttest "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
)
func TestAddSEVSNPVerificationOptions(t *testing.T) {
cmd := &cobra.Command{
Use: "test",
}
result := addSEVSNPVerificationOptions(cmd)
assert.Equal(t, cmd, result)
// Check that important flags are added
flags := []string{
"host_data",
"family_id",
"image_id",
"report_id",
"report_id_ma",
"measurement",
"chip_id",
"minimum_tcb",
"minimum_lauch_tcb",
"guest_policy",
"minimum_guest_svn",
"minimum_build",
"check_crl",
"timeout",
"max_retry_delay",
"require_author_key",
"require_id_block",
"platform_info",
"minimum_version",
"trusted_author_keys",
"trusted_author_key_hashes",
"trusted_id_keys",
"trusted_id_key_hashes",
"product",
"stepping",
"CA_bundles_paths",
"CA_bundles",
}
for _, flagName := range flags {
flag := cmd.Flags().Lookup(flagName)
assert.NotNil(t, flag, "Flag %s should exist", flagName)
}
}
func TestValidateInput(t *testing.T) {
tests := []struct {
name string
setupCfg func()
expectErr bool
errMsg string
}{
{
name: "valid empty config",
setupCfg: func() {
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
},
expectErr: false,
},
{
name: "CA bundles without product name",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{},
RootOfTrust: &check.RootOfTrust{
CabundlePaths: []string{"test.pem"},
ProductLine: "",
},
}
},
expectErr: true,
errMsg: "product name must be set if CA bundles are provided",
},
{
name: "invalid report_data length",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
ReportData: []byte("invalid"),
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "report_data",
},
{
name: "invalid host_data length",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
HostData: []byte("invalid"),
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "host_data",
},
{
name: "invalid family_id length",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
FamilyId: []byte("invalid"),
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "family_id",
},
{
name: "invalid image_id length",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
ImageId: []byte("invalid"),
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "image_id",
},
{
name: "invalid trusted author key hash",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
TrustedAuthorKeyHashes: [][]byte{[]byte("invalid")},
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "trusted_author_key_hash",
},
{
name: "invalid trusted id key hash",
setupCfg: func() {
cfg = check.Config{
Policy: &check.Policy{
TrustedIdKeyHashes: [][]byte{[]byte("invalid")},
},
RootOfTrust: &check.RootOfTrust{},
}
},
expectErr: true,
errMsg: "trusted_id_key_hash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupCfg()
err := validateInput()
if tt.expectErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
func TestParseTrustedKeys(t *testing.T) {
tempDir := t.TempDir()
authorKeyFile := filepath.Join(tempDir, "author.pem")
idKeyFile := filepath.Join(tempDir, "id.pem")
nonExistentFile := filepath.Join(tempDir, "nonexistent.pem")
authorKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
idKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
require.NoError(t, os.WriteFile(authorKeyFile, []byte(authorKeyContent), 0o644))
require.NoError(t, os.WriteFile(idKeyFile, []byte(idKeyContent), 0o644))
tests := []struct {
name string
trustedAuthorKeys []string
trustedIdKeys []string
expectErr bool
}{
{
name: "valid files",
trustedAuthorKeys: []string{authorKeyFile},
trustedIdKeys: []string{idKeyFile},
expectErr: false,
},
{
name: "nonexistent author key file",
trustedAuthorKeys: []string{nonExistentFile},
trustedIdKeys: []string{},
expectErr: true,
},
{
name: "nonexistent id key file",
trustedAuthorKeys: []string{},
trustedIdKeys: []string{nonExistentFile},
expectErr: true,
},
{
name: "empty file lists",
trustedAuthorKeys: []string{},
trustedIdKeys: []string{},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
trustedAuthorKeys = tt.trustedAuthorKeys
trustedIdKeys = tt.trustedIdKeys
err := parseTrustedKeys()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if len(tt.trustedAuthorKeys) > 0 {
assert.Len(t, cfg.Policy.TrustedAuthorKeys, len(tt.trustedAuthorKeys))
assert.Equal(t, []byte(authorKeyContent), cfg.Policy.TrustedAuthorKeys[0])
}
if len(tt.trustedIdKeys) > 0 {
assert.Len(t, cfg.Policy.TrustedIdKeys, len(tt.trustedIdKeys))
assert.Equal(t, []byte(idKeyContent), cfg.Policy.TrustedIdKeys[0])
}
}
})
}
}
func TestParseUints(t *testing.T) {
tests := []struct {
name string
stepping string
platformInfo string
expectErr bool
expectedStep *uint32
expectedPlatform *uint64
}{
{
name: "empty values",
stepping: "",
platformInfo: "",
expectErr: false,
},
{
name: "decimal values",
stepping: "5",
platformInfo: "10",
expectErr: false,
expectedStep: uint32Ptr(5),
expectedPlatform: uint64Ptr(10),
},
{
name: "hex values",
stepping: "0x5",
platformInfo: "0xa",
expectErr: false,
expectedStep: uint32Ptr(5),
expectedPlatform: uint64Ptr(10),
},
{
name: "octal values",
stepping: "0o7",
platformInfo: "0o12",
expectErr: false,
expectedStep: uint32Ptr(7),
expectedPlatform: uint64Ptr(10),
},
{
name: "binary values",
stepping: "0b101",
platformInfo: "0b1010",
expectErr: false,
expectedStep: uint32Ptr(5),
expectedPlatform: uint64Ptr(10),
},
{
name: "invalid stepping",
stepping: "invalid",
platformInfo: "",
expectErr: true,
},
{
name: "invalid platform info",
stepping: "",
platformInfo: "invalid",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
stepping = tt.stepping
platformInfo = tt.platformInfo
err := parseUints()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.expectedStep != nil {
assert.Equal(t, *tt.expectedStep, cfg.Policy.Product.MachineStepping.Value)
}
if tt.expectedPlatform != nil {
assert.Equal(t, *tt.expectedPlatform, cfg.Policy.PlatformInfo.Value)
}
}
})
}
}
func TestGetBase(t *testing.T) {
tests := []struct {
input string
expected int
}{
{"0x10", 16},
{"0o10", 8},
{"0b10", 2},
{"10", 10},
{"", 10},
{"abc", 10},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := getBase(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseConfig(t *testing.T) {
tempDir := t.TempDir()
validConfig := map[string]interface{}{
"rootOfTrust": map[string]interface{}{
"product": "test_product",
"cabundlePaths": []string{"test_path"},
"cabundles": []string{"test_bundle"},
"checkCrl": true,
"disallowNetwork": true,
},
"policy": map[string]interface{}{
"minimumGuestSvn": 1,
"policy": "1",
"minimumBuild": 1,
"minimumVersion": "0.90",
"requireAuthorKey": true,
"requireIdBlock": true,
},
}
tests := []struct {
name string
setupConfig func() string
expectErr bool
}{
{
name: "empty config string",
setupConfig: func() string {
return ""
},
expectErr: false,
},
{
name: "valid config file",
setupConfig: func() string {
configFile := filepath.Join(tempDir, "valid_config.json")
configBytes, err := json.Marshal(validConfig)
assert.NoError(t, err)
if err := os.WriteFile(configFile, configBytes, 0o644); err != nil {
t.Errorf("failed to write config file: %v", err)
}
return configFile
},
expectErr: false,
},
{
name: "nonexistent config file",
setupConfig: func() string {
return filepath.Join(tempDir, "nonexistent.json")
},
expectErr: true,
},
{
name: "invalid JSON config",
setupConfig: func() string {
configFile := filepath.Join(tempDir, "invalid_config.json")
if err := os.WriteFile(configFile, []byte("invalid json"), 0o644); err != nil {
t.Errorf("failed to write invalid config file: %v", err)
}
return configFile
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
cfgString = tt.setupConfig()
err := parseConfig()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, cfg.Policy)
assert.NotNil(t, cfg.RootOfTrust)
}
})
}
}
func TestParseHashes(t *testing.T) {
tests := []struct {
name string
trustedAuthorHashes []string
trustedIdKeyHashes []string
expectErr bool
}{
{
name: "valid hashes",
trustedAuthorHashes: []string{"deadbeef", "cafebabe"},
trustedIdKeyHashes: []string{"12345678", "87654321"},
expectErr: false,
},
{
name: "empty hashes",
trustedAuthorHashes: []string{},
trustedIdKeyHashes: []string{},
expectErr: false,
},
{
name: "invalid author hash",
trustedAuthorHashes: []string{"invalid_hex"},
trustedIdKeyHashes: []string{},
expectErr: true,
},
{
name: "invalid id key hash",
trustedAuthorHashes: []string{},
trustedIdKeyHashes: []string{"invalid_hex"},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
trustedAuthorHashes = tt.trustedAuthorHashes
trustedIdKeyHashes = tt.trustedIdKeyHashes
err := parseHashes()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, len(tt.trustedAuthorHashes))
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, len(tt.trustedIdKeyHashes))
for i, hash := range tt.trustedAuthorHashes {
expected, _ := hex.DecodeString(hash)
assert.Equal(t, expected, cfg.Policy.TrustedAuthorKeyHashes[i])
}
for i, hash := range tt.trustedIdKeyHashes {
expected, _ := hex.DecodeString(hash)
assert.Equal(t, expected, cfg.Policy.TrustedIdKeyHashes[i])
}
}
})
}
}
func TestParseAttestationFile(t *testing.T) {
tempDir := t.TempDir()
binaryFile := filepath.Join(tempDir, "attestation.bin")
jsonFile := filepath.Join(tempDir, "attestation.json")
binaryData := make([]byte, 1024)
for i := range binaryData {
binaryData[i] = byte(i % 256)
}
jsonData := &sevsnp.Attestation{
Report: &sevsnp.Report{
FamilyId: make([]byte, 16),
ImageId: make([]byte, 16),
ReportData: make([]byte, 64),
Measurement: make([]byte, 48),
HostData: make([]byte, 32),
IdKeyDigest: make([]byte, 48),
AuthorKeyDigest: make([]byte, 48),
ReportId: make([]byte, 32),
ReportIdMa: make([]byte, 32),
ChipId: make([]byte, 64),
Signature: make([]byte, 512),
},
}
jsonBytes, err := json.Marshal(jsonData)
require.NoError(t, err)
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
require.NoError(t, os.WriteFile(jsonFile, jsonBytes, 0o644))
tests := []struct {
name string
attestationFile string
expectErr bool
}{
{
name: "valid binary file",
attestationFile: binaryFile,
expectErr: false,
},
{
name: "valid JSON file",
attestationFile: jsonFile,
expectErr: false,
},
{
name: "nonexistent file",
attestationFile: filepath.Join(tempDir, "nonexistent.bin"),
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
attestationFile = tt.attestationFile
err := parseAttestationFile()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, attestationRaw)
assert.NotEmpty(t, attestationRaw)
}
})
}
}
func TestSevsnpverify(t *testing.T) {
trustedAuthorHashes = []string{}
trustedIdKeyHashes = []string{}
stepping = ""
platformInfo = ""
tempDir := t.TempDir()
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
attestationFile := filepath.Join(tempDir, "attestation.bin")
attestationData := make([]byte, abi.ReportSize+100)
for i := range attestationData {
attestationData[i] = byte(i % 256)
}
require.NoError(t, os.WriteFile(attestationFile, attestationData, 0o644))
tests := []struct {
name string
args []string
setupMock func(*mocks.Verifier)
expectErr bool
expectedMsg string
}{
{
name: "successful verification",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(nil)
},
expectErr: false,
expectedMsg: "Attestation validation and verification is successful!",
},
{
name: "verification failure",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
},
expectErr: true,
expectedMsg: "attestation validation and verification failed",
},
{
name: "nonexistent file",
args: []string{filepath.Join(tempDir, "nonexistent.bin")},
setupMock: func(m *mocks.Verifier) {},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfgString = ""
mockVerifier := new(mocks.Verifier)
tt.setupMock(mockVerifier)
var output bytes.Buffer
cmd := &cobra.Command{}
cmd.SetOut(&output)
err := sevsnpverify(cmd, mockVerifier, tt.args)
fmt.Println("error1", err)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.expectedMsg != "" {
assert.Contains(t, output.String(), tt.expectedMsg)
}
}
mockVerifier.AssertExpectations(t)
})
}
}
func TestReturnvTPMAttestation(t *testing.T) {
tempDir := t.TempDir()
attestation := &tpmAttest.Attestation{
Quotes: []*tpm.Quote{
{
Quote: []byte("test quote"),
RawSig: []byte("test signature"),
},
},
}
binaryData, err := proto.Marshal(attestation)
require.NoError(t, err)
binaryFile := filepath.Join(tempDir, "attestation.pb")
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
textData, err := prototext.Marshal(attestation)
require.NoError(t, err)
textFile := filepath.Join(tempDir, "attestation.txtpb")
require.NoError(t, os.WriteFile(textFile, textData, 0o644))
tests := []struct {
name string
args []string
format string
expectErr bool
}{
{
name: "binary protobuf format",
args: []string{binaryFile},
format: FormatBinaryPB,
expectErr: false,
},
{
name: "text protobuf format",
args: []string{textFile},
format: FormatTextProto,
expectErr: false,
},
{
name: "invalid format",
args: []string{binaryFile},
format: "invalid",
expectErr: true,
},
{
name: "nonexistent file",
args: []string{filepath.Join(tempDir, "nonexistent.pb")},
format: FormatBinaryPB,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
format = tt.format
result, err := returnvTPMAttestation(tt.args)
if tt.expectErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result)
}
})
}
}
func TestVtpmSevSnpverify(t *testing.T) {
stepping = ""
platformInfo = ""
trustedAuthorHashes = []string{}
trustedIdKeyHashes = []string{}
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
tempDir := t.TempDir()
attestation := &tpmAttest.Attestation{
Quotes: []*tpm.Quote{
{
Quote: []byte("test quote"),
RawSig: []byte("test signature"),
},
},
}
binaryData, err := proto.Marshal(attestation)
require.NoError(t, err)
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
tests := []struct {
name string
args []string
setupMock func(*mocks.Verifier)
expectErr bool
}{
{
name: "successful verification",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(nil)
},
expectErr: false,
},
{
name: "verification failure",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
cfgString = ""
format = FormatBinaryPB
mockVerifier := new(mocks.Verifier)
tt.setupMock(mockVerifier)
err := vtpmSevSnpverify(tt.args, mockVerifier)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockVerifier.AssertExpectations(t)
})
}
}
func TestVtpmverify(t *testing.T) {
tempDir := t.TempDir()
attestation := &tpmAttest.Attestation{
Quotes: []*tpm.Quote{
{
Quote: []byte("test quote"),
RawSig: []byte("test signature"),
},
},
}
binaryData, err := proto.Marshal(attestation)
require.NoError(t, err)
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
tests := []struct {
name string
args []string
setupMock func(*mocks.Verifier)
expectErr bool
}{
{
name: "successful verification",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(nil)
},
expectErr: false,
},
{
name: "verification failure",
args: []string{attestationFile},
setupMock: func(m *mocks.Verifier) {
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
format = FormatBinaryPB
mockVerifier := new(mocks.Verifier)
tt.setupMock(mockVerifier)
err := vtpmverify(tt.args, mockVerifier)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockVerifier.AssertExpectations(t)
})
}
}
func uint32Ptr(v uint32) *uint32 {
return &v
}
func uint64Ptr(v uint64) *uint64 {
return &v
}
+694 -23
View File
@@ -4,10 +4,12 @@ package cli
import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/absmach/magistrala/pkg/errors"
@@ -18,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
@@ -311,26 +314,12 @@ func TestNewValidateAttestationValidationCmd(t *testing.T) {
})
}
type MockMeasurement struct {
mock.Mock
}
func (m *MockMeasurement) Run(igvmBinaryPath string) ([]byte, error) {
args := m.Called(igvmBinaryPath)
return nil, args.Error(0)
}
func (m *MockMeasurement) Stop() error {
args := m.Called()
return args.Error(0)
}
func TestNewMeasureCmd_RunSuccess(t *testing.T) {
cliInstance := &CLI{}
mockMeasurement := new(MockMeasurement)
mockMeasurement := new(mmocks.MeasurementProvider)
cliInstance.measurement = mockMeasurement
mockMeasurement.On("Run", "testfile.igvm").Return(nil)
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, nil)
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
buf := new(bytes.Buffer)
@@ -346,11 +335,11 @@ func TestNewMeasureCmd_RunSuccess(t *testing.T) {
func TestNewMeasureCmd_RunError(t *testing.T) {
cliInstance := &CLI{}
mockMeasurement := new(MockMeasurement)
mockMeasurement := new(mmocks.MeasurementProvider)
cliInstance.measurement = mockMeasurement
expectedError := errors.New("mocked measurement error")
mockMeasurement.On("Run", "testfile.igvm").Return(expectedError)
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, expectedError)
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
@@ -366,7 +355,7 @@ func TestNewMeasureCmd_RunError(t *testing.T) {
mockMeasurement.AssertExpectations(t)
}
func TestParseConfig(t *testing.T) {
func TestParseConfig1(t *testing.T) {
tmpfile, err := os.CreateTemp("", "attestation_policy.json")
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
@@ -393,7 +382,7 @@ func TestParseConfig(t *testing.T) {
assert.Error(t, err)
}
func TestParseHashes(t *testing.T) {
func TestParseHashes1(t *testing.T) {
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
@@ -444,7 +433,7 @@ func TestParseFiles(t *testing.T) {
assert.Error(t, err)
}
func TestParseUints(t *testing.T) {
func TestParseUints1(t *testing.T) {
stepping = "10"
platformInfo = "0xFF"
@@ -469,7 +458,7 @@ func TestParseUints(t *testing.T) {
assert.Error(t, err)
}
func TestValidateInput(t *testing.T) {
func TestValidateInput1(t *testing.T) {
cfg = check.Config{}
if cfg.Policy == nil {
cfg.Policy = &check.Policy{}
@@ -494,7 +483,7 @@ func TestValidateInput(t *testing.T) {
assert.Error(t, err)
}
func TestGetBase(t *testing.T) {
func TestGetBase1(t *testing.T) {
assert.Equal(t, 16, getBase("0xFF"))
assert.Equal(t, 8, getBase("0o77"))
assert.Equal(t, 2, getBase("0b1010"))
@@ -716,3 +705,685 @@ func TestDecodeJWTToJSON(t *testing.T) {
})
}
}
func setupTestEnvironment() func() {
originalMode := mode
originalCfgString := cfgString
originalTimeout := timeout
originalMaxRetryDelay := maxRetryDelay
originalPlatformInfo := platformInfo
originalStepping := stepping
originalTrustedAuthorKeys := trustedAuthorKeys
originalTrustedAuthorHashes := trustedAuthorHashes
originalTrustedIdKeys := trustedIdKeys
originalTrustedIdKeyHashes := trustedIdKeyHashes
originalAttestationFile := attestationFile
originalAttestationRaw := attestationRaw
originalOutput := output
originalNonce := nonce
originalFormat := format
originalTeeNonce := teeNonce
originalTokenNonce := tokenNonce
originalGetTextProtoAttestationReport := getTextProtoAttestationReport
originalGetAzureTokenJWT := getAzureTokenJWT
originalCloud := cloud
originalReportData := reportData
originalCheckCrl := checkCrl
mode = ""
cfgString = ""
timeout = 0
maxRetryDelay = 0
platformInfo = ""
stepping = ""
trustedAuthorKeys = []string{}
trustedAuthorHashes = []string{}
trustedIdKeys = []string{}
trustedIdKeyHashes = []string{}
attestationFile = ""
attestationRaw = []byte{}
output = ""
nonce = []byte{}
format = ""
teeNonce = []byte{}
tokenNonce = []byte{}
getTextProtoAttestationReport = false
getAzureTokenJWT = false
cloud = ""
reportData = []byte{}
checkCrl = false
return func() {
mode = originalMode
cfgString = originalCfgString
timeout = originalTimeout
maxRetryDelay = originalMaxRetryDelay
platformInfo = originalPlatformInfo
stepping = originalStepping
trustedAuthorKeys = originalTrustedAuthorKeys
trustedAuthorHashes = originalTrustedAuthorHashes
trustedIdKeys = originalTrustedIdKeys
trustedIdKeyHashes = originalTrustedIdKeyHashes
attestationFile = originalAttestationFile
attestationRaw = originalAttestationRaw
output = originalOutput
nonce = originalNonce
format = originalFormat
teeNonce = originalTeeNonce
tokenNonce = originalTokenNonce
getTextProtoAttestationReport = originalGetTextProtoAttestationReport
getAzureTokenJWT = originalGetAzureTokenJWT
cloud = originalCloud
reportData = originalReportData
checkCrl = originalCheckCrl
}
}
func createTempFile(t *testing.T, content []byte) string {
tmpfile, err := os.CreateTemp("", "test_*.bin")
require.NoError(t, err)
defer tmpfile.Close()
_, err = tmpfile.Write(content)
require.NoError(t, err)
return tmpfile.Name()
}
func TestNewAttestationCmdEdgeCases(t *testing.T) {
tests := []struct {
name string
args []string
expectedOutput string
hasSubcommands bool
}{
{
name: "no arguments shows help",
args: []string{},
expectedOutput: "Get and validate attestations",
hasSubcommands: true,
},
{
name: "help flag shows usage",
args: []string{"--help"},
expectedOutput: "Get and validate attestations",
hasSubcommands: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
cmd := cli.NewAttestationCmd()
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
cmd.SetArgs(tt.args)
err := cmd.Execute()
assert.NoError(t, err)
output := buf.String()
assert.Contains(t, output, tt.expectedOutput)
if tt.hasSubcommands {
assert.Contains(t, output, "Get and validate attestations")
assert.Contains(t, output, "Usage:")
assert.Contains(t, output, "Flags:")
}
})
}
}
func TestGetAttestationCmdEdgeCases(t *testing.T) {
defer setupTestEnvironment()()
testCases := []struct {
name string
args []string
setupMock func(*mocks.SDK)
expectedErr string
expectedOut string
}{
{
name: "no arguments provided",
args: []string{},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "accepts 1 arg(s), received 0",
},
{
name: "too many arguments",
args: []string{"snp", "extra"},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "accepts 1 arg(s), received 2",
},
{
name: "invalid attestation type",
args: []string{"invalid-type"},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "Bad attestation type",
},
{
name: "SNP with missing TEE nonce",
args: []string{"snp"},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "TEE nonce must be defined for SEV-SNP attestation",
},
{
name: "vTPM with missing nonce",
args: []string{"vtpm"},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "vTPM nonce must be defined for vTPM attestation",
},
{
name: "Azure token with missing token nonce",
args: []string{"azure-token"},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "Token nonce must be defined for Azure attestation",
},
{
name: "TEE nonce too large",
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
},
{
name: "vTPM nonce too large",
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
},
{
name: "Token nonce too large",
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
},
{
name: "successful TDX attestation",
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
setupMock: func(sdk *mocks.SDK) {
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil).Run(func(args mock.Arguments) {
if _, err := args.Get(4).(*os.File).Write([]byte("mock tdx attestation")); err != nil {
t.Fatalf("Failed to write to attestation file: %v", err)
}
})
},
expectedOut: "Fetching TDX attestation report",
},
{
name: "file creation error",
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "Error creating attestation file",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
os.Remove(attestationFilePath)
os.Remove(azureAttestResultFilePath)
os.Remove(azureAttestTokenFilePath)
defer func() {
os.Remove(attestationFilePath)
os.Remove(azureAttestResultFilePath)
os.Remove(azureAttestTokenFilePath)
}()
mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
tc.setupMock(mockSDK)
if tc.name == "file creation error" {
err := os.Mkdir(attestationFilePath, 0o755)
require.NoError(t, err)
defer os.Remove(attestationFilePath)
}
cmd := cli.NewGetAttestationCmd()
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
cmd.SetArgs(tc.args)
err := cmd.Execute()
output := buf.String()
if tc.expectedErr != "" {
assert.Contains(t, output, tc.expectedErr)
} else {
assert.NoError(t, err)
if tc.expectedOut != "" {
assert.Contains(t, output, tc.expectedOut)
}
}
})
}
}
func TestFileOperations(t *testing.T) {
defer setupTestEnvironment()()
t.Run("openInputFile", func(t *testing.T) {
attestationFile = ""
reader, err := openInputFile()
assert.Error(t, err)
assert.Equal(t, errEmptyFile, err)
assert.Nil(t, reader)
tempFile := createTempFile(t, []byte("test content"))
defer os.Remove(tempFile)
attestationFile = tempFile
reader, err = openInputFile()
assert.NoError(t, err)
assert.NotNil(t, reader)
if file, ok := reader.(*os.File); ok {
file.Close()
}
attestationFile = "non-existent-file.bin"
reader, err = openInputFile()
assert.Error(t, err)
assert.Nil(t, reader)
})
t.Run("createOutputFile", func(t *testing.T) {
output = ""
writer, err := createOutputFile()
assert.NoError(t, err)
assert.Equal(t, os.Stdout, writer)
tempDir := t.TempDir()
output = filepath.Join(tempDir, "test_output.txt")
writer, err = createOutputFile()
assert.NoError(t, err)
assert.NotNil(t, writer)
if file, ok := writer.(*os.File); ok {
file.Close()
}
output = "/invalid/path/file.txt"
writer, err = createOutputFile()
assert.Error(t, err)
assert.Nil(t, writer)
})
}
func TestValidationFunctions(t *testing.T) {
t.Run("validateFieldLength", func(t *testing.T) {
tests := []struct {
name string
fieldName string
field []byte
expectedLength int
expectError bool
}{
{
name: "nil field",
fieldName: "test",
field: nil,
expectedLength: 32,
expectError: false,
},
{
name: "correct length",
fieldName: "test",
field: make([]byte, 32),
expectedLength: 32,
expectError: false,
},
{
name: "incorrect length",
fieldName: "test",
field: make([]byte, 16),
expectedLength: 32,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateFieldLength(tt.fieldName, tt.field, tt.expectedLength)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.fieldName)
} else {
assert.NoError(t, err)
}
})
}
})
}
func TestDecodeJWTToJSONEdgeCases(t *testing.T) {
tests := []struct {
name string
input []byte
expected string
hasError bool
}{
{
name: "empty input",
input: []byte(""),
hasError: true,
},
{
name: "single part",
input: []byte("onlyonepart"),
hasError: true,
},
{
name: "invalid base64 in header",
input: []byte("invalid@base64.validpart"),
hasError: true,
},
{
name: "invalid base64 in payload",
input: []byte("eyJhbGciOiJIUzI1NiJ9.invalid@base64"),
hasError: true,
},
{
name: "invalid JSON in header",
input: []byte("bm90anNvbg.eyJzdWIiOiJ0ZXN0In0"),
hasError: true,
},
{
name: "invalid JSON in payload",
input: []byte("eyJhbGciOiJIUzI1NiJ9.bm90anNvbg"),
hasError: true,
},
{
name: "valid JWT with padding",
input: []byte("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"),
expected: `{
"header": {
"alg": "HS256"
},
"payload": {
"sub": "test"
}
}`,
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := decodeJWTToJSON(tt.input)
if tt.hasError {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
if tt.expected != "" {
assert.JSONEq(t, tt.expected, string(result))
}
}
})
}
}
func TestMeasureCmdEdgeCases(t *testing.T) {
tests := []struct {
name string
args []string
mockSetup func(*mmocks.MeasurementProvider)
expectedError string
expectedOut string
}{
{
name: "no arguments",
args: []string{},
mockSetup: func(m *mmocks.MeasurementProvider) {
},
expectedError: "requires at least 1 arg(s), only received 0",
},
{
name: "single line output success",
args: []string{"test.igvm"},
mockSetup: func(m *mmocks.MeasurementProvider) {
m.On("Run", "test.igvm").Return([]byte("ABCDEF123456"), nil)
},
expectedOut: "",
},
{
name: "multi-line output error",
args: []string{"test.igvm"},
mockSetup: func(m *mmocks.MeasurementProvider) {
m.On("Run", "test.igvm").Return([]byte("line1\nline2\nERROR: something went wrong"), nil)
},
expectedError: "ERROR: something went wrong",
},
{
name: "measurement run error",
args: []string{"test.igvm"},
mockSetup: func(m *mmocks.MeasurementProvider) {
m.On("Run", "test.igvm").Return(nil, errors.New("measurement failed"))
},
expectedError: "measurement failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockMeasurement := new(mmocks.MeasurementProvider)
tt.mockSetup(mockMeasurement)
cli := &CLI{measurement: mockMeasurement}
cmd := cli.NewMeasureCmd("fake_binary_path")
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
cmd.SetArgs(tt.args)
err := cmd.Execute()
if tt.expectedError != "" {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedError)
} else {
assert.NoError(t, err)
if tt.expectedOut != "" {
assert.Contains(t, buf.String(), tt.expectedOut)
}
}
mockMeasurement.AssertExpectations(t)
})
}
}
func TestValidateAttestationValidationCmdPreRunE(t *testing.T) {
tests := []struct {
name string
args []string
flags map[string]string
expectedErr string
}{
{
name: "no file path provided",
args: []string{},
flags: map[string]string{"mode": "snp"},
expectedErr: "please pass the attestation report file path",
},
{
name: "multiple file paths",
args: []string{"file1.bin", "file2.bin"},
flags: map[string]string{"mode": "snp"},
expectedErr: "please pass the attestation report file path",
},
{
name: "unknown mode",
args: []string{"test.bin"},
flags: map[string]string{"mode": "unknown"},
expectedErr: "unknown mode: unknown",
},
{
name: "SNP mode missing report_data",
args: []string{"test.bin"},
flags: map[string]string{"mode": "snp"},
expectedErr: "",
},
{
name: "SNP mode missing product",
args: []string{"test.bin"},
flags: map[string]string{"mode": "snp", "report_data": "123"},
expectedErr: "",
},
{
name: "vTPM mode missing nonce",
args: []string{"test.bin"},
flags: map[string]string{"mode": "vtpm"},
expectedErr: "",
},
{
name: "SNP-vTPM mode missing required flags",
args: []string{"test.bin"},
flags: map[string]string{"mode": "snp-vtpm"},
expectedErr: "",
},
{
name: "TDX mode missing report_data",
args: []string{"test.bin"},
flags: map[string]string{"mode": "tdx"},
expectedErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli := &CLI{}
cmd := cli.NewValidateAttestationValidationCmd()
for key, value := range tt.flags {
if err := cmd.Flags().Set(key, value); err != nil {
}
}
err := cmd.PreRunE(cmd, tt.args)
if tt.expectedErr != "" {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedErr)
} else {
assert.NoError(t, err)
}
})
}
}
func TestCloudProviderConfigurations(t *testing.T) {
defer setupTestEnvironment()()
tests := []struct {
name string
cloud string
expectedType string
}{
{
name: "none cloud provider",
cloud: CCNone,
expectedType: "vtpm",
},
{
name: "azure cloud provider",
cloud: CCAzure,
expectedType: "azure",
},
{
name: "gcp cloud provider",
cloud: CCGCP,
expectedType: "vtpm",
},
{
name: "default cloud provider",
cloud: "",
expectedType: "vtpm",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cli := &CLI{}
cmd := cli.NewValidateAttestationValidationCmd()
if err := cmd.Flags().Set("cloud", tt.cloud); err != nil {
t.Fatalf("Failed to set cloud flag: %v", err)
}
cloud, _ := cmd.Flags().GetString("cloud")
assert.Equal(t, tt.cloud, cloud)
assert.Contains(t, cmd.Short, tt.cloud)
})
}
}
func TestFileOperationErrors(t *testing.T) {
defer setupTestEnvironment()()
t.Run("file close error handling", func(t *testing.T) {
tempFile := createTempFile(t, []byte("test content"))
defer os.Remove(tempFile)
assert.True(t, true)
})
t.Run("file write error handling", func(t *testing.T) {
tempFile := createTempFile(t, []byte("test content"))
defer os.Remove(tempFile)
err := os.Chmod(tempFile, 0o444)
require.NoError(t, err)
err = os.WriteFile(tempFile, []byte("new content"), 0o644)
assert.Error(t, err)
})
t.Run("file read error handling", func(t *testing.T) {
tempDir := t.TempDir()
_, err := os.ReadFile(tempDir)
assert.Error(t, err)
})
}
func TestContextCancellation(t *testing.T) {
defer setupTestEnvironment()()
mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
ctx, cancel := context.WithCancel(context.Background())
cancel()
mockSDK.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(context.Canceled)
cmd := cli.NewGetAttestationCmd()
cmd.SetContext(ctx)
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
err := cmd.Execute()
assert.NoError(t, err)
assert.Contains(t, buf.String(), "Failed to get attestation due to error")
}
+169
View File
@@ -0,0 +1,169 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
)
func TestCLI_NewIMAMeasurementsCmd(t *testing.T) {
testCases := []struct {
name string
args []string
connectErr error
mockIMAData string
mockError error
expectedFilename string
expectedOutput []string
expectedError []string
shouldCreateFile bool
fileCreationError bool
invalidDigestData bool
setupCustomFile func(filename string) error
}{
{
name: "successful_retrieval_default_filename",
args: []string{},
connectErr: nil,
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
mockError: nil,
expectedFilename: imaMeasurementsFilename,
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "PCR10 = 0000000000000000000000000000000000000000", "Measurements file verified!"},
shouldCreateFile: true,
},
{
name: "successful_retrieval_custom_filename",
args: []string{"custom_ima_file.txt"},
connectErr: nil,
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
mockError: nil,
expectedFilename: "custom_ima_file.txt",
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "custom_ima_file.txt", "Measurements file verified!"},
shouldCreateFile: true,
},
{
name: "connection_error",
args: []string{},
connectErr: fmt.Errorf("connection failed"),
expectedError: []string{"Failed to connect to agent: connection failed ❌"},
},
{
name: "file_creation_error",
args: []string{"/invalid/path/file.txt"},
connectErr: nil,
fileCreationError: true,
expectedError: []string{"Error creating imaMeasurements file:"},
},
{
name: "sdk_error",
args: []string{},
connectErr: nil,
mockError: fmt.Errorf("SDK communication failed"),
expectedError: []string{"Error retrieving Linux IMA measurements file: SDK communication failed ❌"},
},
{
name: "verification_failure_wrong_pcr",
args: []string{},
connectErr: nil,
mockIMAData: "10 9999999999999999999999999999999999999999 ima-ng sha1:0000000000000000000000000000000000000000 /usr/bin/test",
mockError: nil,
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully"},
expectedError: []string{"Measurements file not verified ❌"},
shouldCreateFile: true,
},
{
name: "empty_measurements_file",
args: []string{},
connectErr: nil,
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
mockError: nil,
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
shouldCreateFile: true,
},
{
name: "measurements_with_non_pcr10_entries",
args: []string{},
connectErr: nil,
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
mockError: nil,
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
shouldCreateFile: true,
},
{
name: "measurements_with_zero_digest_replacement",
args: []string{},
connectErr: nil,
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
mockError: nil,
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
shouldCreateFile: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockSDK := new(mocks.SDK)
cli := &CLI{
agentSDK: mockSDK,
connectErr: tc.connectErr,
}
if tc.connectErr == nil && !tc.fileCreationError {
mockSDK.On("IMAMeasurements", mock.Anything, mock.Anything).Return([]byte(tc.mockIMAData), tc.mockError)
}
cmd := cli.NewIMAMeasurementsCmd()
var output bytes.Buffer
cmd.SetOut(&output)
cmd.SetErr(&output)
expectedFilename := tc.expectedFilename
if expectedFilename == "" {
if len(tc.args) > 0 {
expectedFilename = tc.args[0]
} else {
expectedFilename = imaMeasurementsFilename
}
}
if tc.setupCustomFile != nil {
err := tc.setupCustomFile(expectedFilename)
assert.NoError(t, err)
}
cmd.SetArgs(tc.args)
err := cmd.Execute()
assert.NoError(t, err, "Command execution failed")
outputStr := output.String()
for _, expectedMsg := range tc.expectedOutput {
assert.Contains(t, outputStr, expectedMsg, "Expected output message not found")
}
for _, expectedErr := range tc.expectedError {
assert.Contains(t, outputStr, expectedErr, "Expected error message not found")
}
if tc.shouldCreateFile && tc.connectErr == nil && !tc.fileCreationError && tc.mockError == nil {
if _, err := os.Stat(expectedFilename); err == nil {
os.Remove(expectedFilename)
}
}
if tc.connectErr == nil && !tc.fileCreationError {
mockSDK.AssertExpectations(t)
}
})
}
}
+10 -6
View File
@@ -38,9 +38,11 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
Example: `create-vm`,
Args: cobra.ExactArgs(0),
Run: func(cmd *cobra.Command, args []string) {
if err := c.InitializeManagerClient(cmd); err != nil {
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
return
if c.managerClient == nil || c.connectErr != nil {
if err := c.InitializeManagerClient(cmd); err != nil {
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
return
}
}
defer c.Close()
@@ -74,7 +76,7 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA")
cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key")
cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt")
cmd.Flags().StringVar(&agentCVMCaUrl, agentCVMCaUrl, "", "CVM CA service URL")
cmd.Flags().StringVar(&agentCVMCaUrl, caUrl, "", "CVM CA service URL")
cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level")
cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM")
if err := cmd.MarkFlagRequired(serverURL); err != nil {
@@ -92,8 +94,10 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
Example: `remove-vm <cvm_id>`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
if err := c.InitializeManagerClient(cmd); err == nil {
defer c.Close()
if c.managerClient == nil || c.connectErr != nil {
if err := c.InitializeManagerClient(cmd); err == nil {
defer c.Close()
}
}
if c.connectErr != nil {
+600
View File
@@ -0,0 +1,600 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"errors"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/mocks"
"google.golang.org/protobuf/types/known/emptypb"
)
func TestCLI_NewCreateVMCmd(t *testing.T) {
tests := []struct {
name string
setupMock func(*mocks.ManagerServiceClient)
setupCLI func(*CLI)
setupFiles func(string) error
flags map[string]string
expectedOutput string
expectedError string
expectError bool
}{
{
name: "successful VM creation with all flags",
setupMock: func(m *mocks.ManagerServiceClient) {
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
return req.AgentCvmServerUrl == "https://server.com" &&
req.AgentLogLevel == "debug" &&
req.AgentCvmCaUrl == "https://ca.com" &&
req.Ttl == "1h0m0s" &&
string(req.AgentCvmServerCaCert) == "ca-cert-content" &&
string(req.AgentCvmClientKey) == "client-key-content" &&
string(req.AgentCvmClientCert) == "client-cert-content"
})).Return(&manager.CreateRes{
CvmId: "vm-123",
ForwardedPort: "8080",
}, nil)
},
setupCLI: func(cli *CLI) {
},
setupFiles: func(tmpDir string) error {
files := map[string]string{
"server-ca.pem": "ca-cert-content",
"client-key.pem": "client-key-content",
"client-crt.pem": "client-cert-content",
}
for filename, content := range files {
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
return err
}
}
return nil
},
flags: map[string]string{
"server-url": "https://server.com",
"server-ca": "server-ca.pem",
"client-key": "client-key.pem",
"client-crt": "client-crt.pem",
"ca-url": "https://ca.com",
"log-level": "debug",
"ttl": "1h",
},
expectedOutput: "✅ Virtual machine created successfully with id vm-123 and port 8080",
expectError: false,
},
{
name: "successful VM creation with minimal flags",
setupMock: func(m *mocks.ManagerServiceClient) {
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
return req.AgentCvmServerUrl == "https://server.com" &&
req.AgentLogLevel == "" &&
req.AgentCvmCaUrl == "" &&
req.Ttl == "" &&
len(req.AgentCvmServerCaCert) == 0 &&
len(req.AgentCvmClientKey) == 0 &&
len(req.AgentCvmClientCert) == 0
})).Return(&manager.CreateRes{
CvmId: "vm-456",
ForwardedPort: "9090",
}, nil)
},
setupCLI: func(cli *CLI) {
},
setupFiles: func(tmpDir string) error {
return nil // No files needed for minimal test
},
flags: map[string]string{
"server-url": "https://server.com",
},
expectedOutput: "✅ Virtual machine created successfully with id vm-456 and port 9090",
expectError: false,
},
{
name: "manager client initialization failure",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as initialization fails
},
setupCLI: func(cli *CLI) {
cli.connectErr = errors.New("connection failed")
},
setupFiles: func(tmpDir string) error {
return nil
},
flags: map[string]string{
"server-url": "https://server.com",
},
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
expectError: true,
},
{
name: "certificate loading failure",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as cert loading fails
},
setupCLI: func(cli *CLI) {
},
setupFiles: func(tmpDir string) error {
return nil // Don't create the cert file
},
flags: map[string]string{
"server-url": "https://server.com",
"server-ca": "nonexistent-ca.pem",
},
expectedError: "Error loading certs:",
expectError: true,
},
{
name: "CreateVm API call failure",
setupMock: func(m *mocks.ManagerServiceClient) {
m.On("CreateVm", mock.Anything, mock.Anything).Return(nil, errors.New("API error"))
},
setupCLI: func(cli *CLI) {
},
setupFiles: func(tmpDir string) error {
return nil
},
flags: map[string]string{
"server-url": "https://server.com",
},
expectedError: "Error creating virtual machine: API error ❌",
expectError: true,
},
{
name: "missing required server-url flag",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as command validation fails
},
setupCLI: func(cli *CLI) {
},
setupFiles: func(tmpDir string) error {
return nil
},
flags: map[string]string{}, // No server-url flag
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "cli-test-")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
oldDir, err := os.Getwd()
require.NoError(t, err)
err = os.Chdir(tmpDir)
require.NoError(t, err)
t.Cleanup(func() {
err := os.Chdir(oldDir)
require.NoError(t, err)
})
err = tt.setupFiles(tmpDir)
require.NoError(t, err)
mockClient := new(mocks.ManagerServiceClient)
tt.setupMock(mockClient)
mockCLI := &CLI{
managerClient: mockClient,
}
tt.setupCLI(mockCLI)
cmd := mockCLI.NewCreateVMCmd()
for flag, value := range tt.flags {
err := cmd.Flags().Set(flag, value)
require.NoError(t, err)
}
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err = cmd.Execute()
if tt.expectError {
if tt.expectedError != "" {
assert.Contains(t, buf.String(), tt.expectedError)
}
} else {
assert.NoError(t, err)
if tt.expectedOutput != "" {
assert.Contains(t, buf.String(), tt.expectedOutput)
}
}
mockClient.AssertExpectations(t)
})
}
}
func TestCLI_NewRemoveVMCmd(t *testing.T) {
tests := []struct {
name string
setupMock func(*mocks.ManagerServiceClient)
setupCLI func(*CLI)
args []string
expectedOutput string
expectedError string
expectError bool
}{
{
name: "successful VM removal",
setupMock: func(m *mocks.ManagerServiceClient) {
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
CvmId: "vm-123",
}).Return(&emptypb.Empty{}, nil)
},
setupCLI: func(cli *CLI) {
},
args: []string{"vm-123"},
expectedOutput: "✅ Virtual machine removed successfully",
expectError: false,
},
{
name: "manager client initialization failure",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as initialization fails
},
setupCLI: func(cli *CLI) {
cli.connectErr = errors.New("connection failed")
},
args: []string{"vm-123"},
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
expectError: true,
},
{
name: "RemoveVm API call failure",
setupMock: func(m *mocks.ManagerServiceClient) {
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
CvmId: "vm-456",
}).Return(nil, errors.New("removal failed"))
},
setupCLI: func(cli *CLI) {
},
args: []string{"vm-456"},
expectedError: "Error removing virtual machine: removal failed ❌",
expectError: true,
},
{
name: "missing VM ID argument",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as command validation fails
},
setupCLI: func(cli *CLI) {
},
args: []string{}, // No VM ID provided
expectError: true,
},
{
name: "too many arguments",
setupMock: func(m *mocks.ManagerServiceClient) {
// No expectations set as command validation fails
},
setupCLI: func(cli *CLI) {
},
args: []string{"vm-123", "extra-arg"},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := new(mocks.ManagerServiceClient)
tt.setupMock(mockClient)
mockCLI := &CLI{
managerClient: mockClient,
}
tt.setupCLI(mockCLI)
cmd := mockCLI.NewRemoveVMCmd()
cmd.SetArgs(tt.args)
var buf bytes.Buffer
cmd.SetOut(&buf)
cmd.SetErr(&buf)
err := cmd.Execute()
if tt.expectError {
if tt.expectedError != "" {
assert.Contains(t, buf.String(), tt.expectedError)
}
} else {
assert.NoError(t, err)
if tt.expectedOutput != "" {
assert.Contains(t, buf.String(), tt.expectedOutput)
}
}
mockClient.AssertExpectations(t)
})
}
}
func TestFileReader(t *testing.T) {
tests := []struct {
name string
setupFile func(string) (string, error)
path string
expectedResult []byte
expectError bool
}{
{
name: "successful file read",
setupFile: func(tmpDir string) (string, error) {
filePath := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(filePath, []byte("test content"), 0o644)
return filePath, err
},
expectedResult: []byte("test content"),
expectError: false,
},
{
name: "empty path returns nil",
setupFile: func(tmpDir string) (string, error) {
return "", nil
},
path: "",
expectedResult: nil,
expectError: false,
},
{
name: "nonexistent file returns error",
setupFile: func(tmpDir string) (string, error) {
return filepath.Join(tmpDir, "nonexistent.txt"), nil
},
expectedResult: nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "fileReader-test-")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
filePath, err := tt.setupFile(tmpDir)
require.NoError(t, err)
if tt.path != "" {
filePath = tt.path
}
result, err := fileReader(filePath)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedResult, result)
}
})
}
}
func TestLoadCerts(t *testing.T) {
tests := []struct {
name string
setupFiles func(string) error
setupGlobal func(string)
expectError bool
validate func(*testing.T, *manager.CreateReq)
}{
{
name: "successful cert loading with all files",
setupFiles: func(tmpDir string) error {
files := map[string]string{
"client.key": "client-key-content",
"client.crt": "client-cert-content",
"server.ca": "server-ca-content",
}
for filename, content := range files {
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
return err
}
}
return nil
},
setupGlobal: func(tmpDir string) {
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
agentCVMServerCA = filepath.Join(tmpDir, "server.ca")
},
expectError: false,
validate: func(t *testing.T, req *manager.CreateReq) {
assert.Equal(t, []byte("client-key-content"), req.AgentCvmClientKey)
assert.Equal(t, []byte("client-cert-content"), req.AgentCvmClientCert)
assert.Equal(t, []byte("server-ca-content"), req.AgentCvmServerCaCert)
},
},
{
name: "successful cert loading with empty paths",
setupFiles: func(tmpDir string) error {
return nil
},
setupGlobal: func(tmpDir string) {
agentCVMClientKey = ""
agentCVMClientCrt = ""
agentCVMServerCA = ""
},
expectError: false,
validate: func(t *testing.T, req *manager.CreateReq) {
assert.Nil(t, req.AgentCvmClientKey)
assert.Nil(t, req.AgentCvmClientCert)
assert.Nil(t, req.AgentCvmServerCaCert)
},
},
{
name: "client key file read error",
setupFiles: func(tmpDir string) error {
return nil // Don't create client key file
},
setupGlobal: func(tmpDir string) {
agentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
agentCVMClientCrt = ""
agentCVMServerCA = ""
},
expectError: true,
},
{
name: "client cert file read error",
setupFiles: func(tmpDir string) error {
// Create client key but not cert
return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644)
},
setupGlobal: func(tmpDir string) {
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
agentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
agentCVMServerCA = ""
},
expectError: true,
},
{
name: "server CA file read error",
setupFiles: func(tmpDir string) error {
files := map[string]string{
"client.key": "client-key-content",
"client.crt": "client-cert-content",
}
for filename, content := range files {
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
return err
}
}
return nil
},
setupGlobal: func(tmpDir string) {
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
agentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "loadCerts-test-")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
err = tt.setupFiles(tmpDir)
require.NoError(t, err)
// Store original global variables
origClientKey := agentCVMClientKey
origClientCrt := agentCVMClientCrt
origServerCA := agentCVMServerCA
// Setup global variables for test
tt.setupGlobal(tmpDir)
// Restore original values after test
defer func() {
agentCVMClientKey = origClientKey
agentCVMClientCrt = origClientCrt
agentCVMServerCA = origServerCA
}()
result, err := loadCerts()
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
func TestCommandCreation(t *testing.T) {
cli := &CLI{}
t.Run("create-vm command creation", func(t *testing.T) {
cmd := cli.NewCreateVMCmd()
assert.NotNil(t, cmd)
assert.Equal(t, "create-vm", cmd.Use)
assert.Equal(t, "Create a new virtual machine", cmd.Short)
// Check that required flags are set
flag := cmd.Flags().Lookup("server-url")
assert.NotNil(t, flag)
// Note: We can't easily test MarkFlagRequired in unit tests
})
t.Run("remove-vm command creation", func(t *testing.T) {
cmd := cli.NewRemoveVMCmd()
assert.NotNil(t, cmd)
assert.Equal(t, "remove-vm", cmd.Use)
assert.Equal(t, "Remove a virtual machine", cmd.Short)
})
}
func TestTTLHandling(t *testing.T) {
tests := []struct {
name string
ttlInput string
expectedTTL time.Duration
expectError bool
}{
{
name: "valid duration",
ttlInput: "1h30m",
expectedTTL: time.Hour + 30*time.Minute,
expectError: false,
},
{
name: "zero duration",
ttlInput: "0",
expectedTTL: 0,
expectError: false,
},
{
name: "empty string",
ttlInput: "",
expectedTTL: 0,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockCLI := &CLI{
managerClient: new(mocks.ManagerServiceClient),
}
cmd := mockCLI.NewCreateVMCmd()
if tt.ttlInput != "" {
err := cmd.Flags().Set("ttl", tt.ttlInput)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedTTL, ttl)
}
}
})
}
}
+3 -1
View File
@@ -62,5 +62,7 @@ func (c *CLI) InitializeManagerClient(cmd *cobra.Command) error {
}
func (c *CLI) Close() {
c.client.Close()
if c.client != nil {
c.client.Close()
}
}