mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
* Implement IMAMeasurements method in agentSDK and add corresponding unit tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for NewIMAMeasurements command in CLI Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add error assertion for command execution in NewIMAMeasurements test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix nil pointer dereference in Close method and update NewCreateVMCmd logic for manager client initialization Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation and improve cleanup handling in NewCreateVMCmd test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for state machine functionality Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementation for Algorithm interface and corresponding test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation in TestStopComputationIntegration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove redundant reset test cases from TestStateMachine_Reset Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix race condition in action call verification in TestStateMachine_HandleEvent Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance state machine with reset functionality and improve thread safety in event handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in state machine start function during tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove concurrent reset and send event test from state machine tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove error logging for Start function in transition tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementations for AgentService_IMAMeasurementsClient and Service Shutdown method; enhance progress tests for IMA measurements handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for FileStorage functionality including loading, saving, and concurrent access Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance tests by adding dataset and algorithm hashes in handleRunReqChunks; improve error handling in TestFileStorage_ErrorHandling cleanup Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestManagerClient_Process by adding new test cases for Agent state and Disconnect requests; update setupMocks to include grpcClient Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix graceful shutdown in gRPC server by adding nil checks for health and server instances Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestAttestation by adding mock expectations for VTpmAttestation and Attestation methods; update service call to include platform parameter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance gRPC Server by adding synchronization for start/stop methods; prevent multiple starts and ensure graceful shutdown Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for gRPC server methods including VM creation, removal, and info retrieval Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for SEVSNP and TDX host capabilities; remove unused vsock code Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add a newline for better readability in vm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add integration tests for gRPC client in cvm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove unused vsock dependencies and add comprehensive unit tests for GCP attestation functions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip GCP tests if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for error handling in attestation configuration and GCP commands Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in Azure VM test response writing Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip tests in GCP functions if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for Azure attestation provider and verifier Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for TPM functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for attestation functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add validation for teeNonce in TeeAttestation and implement comprehensive tests for provider methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor error messages in TDX attestation tests for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix error message in TeeAttestation test for valid nonce case Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MeasurementProvider mock and update mockery configuration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add logging for product in parseUints and rename test functions for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor TestSevsnpverify to reset configuration and improve error logging Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
85a2b7a6c8
commit
4e8057f481
@@ -3,6 +3,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"os"
|
||||
@@ -131,3 +132,247 @@ func TestNewAddHostDataCmd(t *testing.T) {
|
||||
assert.Equal(t, "hostdata <host-data> <attestation_policy.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestChangeAttestationConfigurationFileErrors(t *testing.T) {
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
err := changeAttestationConfiguration("nonexistent.json", base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error while reading the attestation policy file")
|
||||
})
|
||||
|
||||
t.Run("Invalid JSON Content", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "invalid.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = changeAttestationConfiguration(tmpfile.Name(), base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal json")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewGCPAttestationPolicy(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewGCPAttestationPolicy()
|
||||
|
||||
assert.Equal(t, "gcp", cmd.Use)
|
||||
assert.Equal(t, "Get attestation policy for GCP CVM", cmd.Short)
|
||||
assert.Equal(t, "gcp <bin_vtmp_attestation_report_file> <vcpu_count>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.bin", "4"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid vCPU Count", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "invalid"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error converting vCPU count to integer")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid Attestation Data", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "4"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error unmarshaling attestation report")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewDownloadGCPOvmfFile(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewDownloadGCPOvmfFile()
|
||||
|
||||
assert.Equal(t, "download", cmd.Use)
|
||||
assert.Equal(t, "Download GCP OVMF file", cmd.Short)
|
||||
assert.Equal(t, "download <bin_vtmp_attestation_report_file>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.bin"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid Attestation Data", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name()})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error unmarshaling attestation report")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAzureAttestationPolicy(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAzureAttestationPolicy()
|
||||
|
||||
assert.Equal(t, "azure", cmd.Use)
|
||||
assert.Equal(t, "Get attestation policy for Azure CVM", cmd.Short)
|
||||
assert.Equal(t, "azure <azure_maa_token_file> <product_name>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
flag := cmd.Flags().Lookup("policy")
|
||||
assert.NotNil(t, flag)
|
||||
assert.Equal(t, "Policy of the guest CVM", flag.Usage)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.token", "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Valid Token File", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "token.maa")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer os.Remove("attestation_policy.json")
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Custom Policy Flag", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "token.maa")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{"--policy", "123456", tmpfile.Name(), "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
flag := cmd.Flags().Lookup("policy")
|
||||
assert.NotNil(t, flag)
|
||||
assert.Equal(t, "123456", flag.Value.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
|
||||
t.Run("Measurement Command Error", func(t *testing.T) {
|
||||
cmd := cli.NewAddMeasurementCmd()
|
||||
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error could not change measurement data")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Host Data Command Error", func(t *testing.T) {
|
||||
cmd := cli.NewAddHostDataCmd()
|
||||
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error could not change host data")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -331,6 +331,7 @@ func parseUints() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Policy.Product.MachineStepping = wrapperspb.UInt32(uint32(num))
|
||||
} else {
|
||||
num, err := strconv.ParseUint(stepping[2:], base, 8)
|
||||
|
||||
@@ -0,0 +1,870 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
tpmAttest "github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/google/go-tpm-tools/proto/tpm"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestAddSEVSNPVerificationOptions(t *testing.T) {
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
}
|
||||
|
||||
result := addSEVSNPVerificationOptions(cmd)
|
||||
|
||||
assert.Equal(t, cmd, result)
|
||||
|
||||
// Check that important flags are added
|
||||
flags := []string{
|
||||
"host_data",
|
||||
"family_id",
|
||||
"image_id",
|
||||
"report_id",
|
||||
"report_id_ma",
|
||||
"measurement",
|
||||
"chip_id",
|
||||
"minimum_tcb",
|
||||
"minimum_lauch_tcb",
|
||||
"guest_policy",
|
||||
"minimum_guest_svn",
|
||||
"minimum_build",
|
||||
"check_crl",
|
||||
"timeout",
|
||||
"max_retry_delay",
|
||||
"require_author_key",
|
||||
"require_id_block",
|
||||
"platform_info",
|
||||
"minimum_version",
|
||||
"trusted_author_keys",
|
||||
"trusted_author_key_hashes",
|
||||
"trusted_id_keys",
|
||||
"trusted_id_key_hashes",
|
||||
"product",
|
||||
"stepping",
|
||||
"CA_bundles_paths",
|
||||
"CA_bundles",
|
||||
}
|
||||
|
||||
for _, flagName := range flags {
|
||||
flag := cmd.Flags().Lookup(flagName)
|
||||
assert.NotNil(t, flag, "Flag %s should exist", flagName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCfg func()
|
||||
expectErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid empty config",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "CA bundles without product name",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{},
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
CabundlePaths: []string{"test.pem"},
|
||||
ProductLine: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "product name must be set if CA bundles are provided",
|
||||
},
|
||||
{
|
||||
name: "invalid report_data length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
ReportData: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "report_data",
|
||||
},
|
||||
{
|
||||
name: "invalid host_data length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
HostData: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "host_data",
|
||||
},
|
||||
{
|
||||
name: "invalid family_id length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
FamilyId: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "family_id",
|
||||
},
|
||||
{
|
||||
name: "invalid image_id length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
ImageId: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "image_id",
|
||||
},
|
||||
{
|
||||
name: "invalid trusted author key hash",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
TrustedAuthorKeyHashes: [][]byte{[]byte("invalid")},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "trusted_author_key_hash",
|
||||
},
|
||||
{
|
||||
name: "invalid trusted id key hash",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
TrustedIdKeyHashes: [][]byte{[]byte("invalid")},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "trusted_id_key_hash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupCfg()
|
||||
err := validateInput()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustedKeys(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authorKeyFile := filepath.Join(tempDir, "author.pem")
|
||||
idKeyFile := filepath.Join(tempDir, "id.pem")
|
||||
nonExistentFile := filepath.Join(tempDir, "nonexistent.pem")
|
||||
|
||||
authorKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
|
||||
idKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
|
||||
|
||||
require.NoError(t, os.WriteFile(authorKeyFile, []byte(authorKeyContent), 0o644))
|
||||
require.NoError(t, os.WriteFile(idKeyFile, []byte(idKeyContent), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
trustedAuthorKeys []string
|
||||
trustedIdKeys []string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid files",
|
||||
trustedAuthorKeys: []string{authorKeyFile},
|
||||
trustedIdKeys: []string{idKeyFile},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent author key file",
|
||||
trustedAuthorKeys: []string{nonExistentFile},
|
||||
trustedIdKeys: []string{},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "nonexistent id key file",
|
||||
trustedAuthorKeys: []string{},
|
||||
trustedIdKeys: []string{nonExistentFile},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file lists",
|
||||
trustedAuthorKeys: []string{},
|
||||
trustedIdKeys: []string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
trustedAuthorKeys = tt.trustedAuthorKeys
|
||||
trustedIdKeys = tt.trustedIdKeys
|
||||
|
||||
err := parseTrustedKeys()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.trustedAuthorKeys) > 0 {
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeys, len(tt.trustedAuthorKeys))
|
||||
assert.Equal(t, []byte(authorKeyContent), cfg.Policy.TrustedAuthorKeys[0])
|
||||
}
|
||||
if len(tt.trustedIdKeys) > 0 {
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeys, len(tt.trustedIdKeys))
|
||||
assert.Equal(t, []byte(idKeyContent), cfg.Policy.TrustedIdKeys[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stepping string
|
||||
platformInfo string
|
||||
expectErr bool
|
||||
expectedStep *uint32
|
||||
expectedPlatform *uint64
|
||||
}{
|
||||
{
|
||||
name: "empty values",
|
||||
stepping: "",
|
||||
platformInfo: "",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "decimal values",
|
||||
stepping: "5",
|
||||
platformInfo: "10",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "hex values",
|
||||
stepping: "0x5",
|
||||
platformInfo: "0xa",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "octal values",
|
||||
stepping: "0o7",
|
||||
platformInfo: "0o12",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(7),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "binary values",
|
||||
stepping: "0b101",
|
||||
platformInfo: "0b1010",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "invalid stepping",
|
||||
stepping: "invalid",
|
||||
platformInfo: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid platform info",
|
||||
stepping: "",
|
||||
platformInfo: "invalid",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
|
||||
stepping = tt.stepping
|
||||
platformInfo = tt.platformInfo
|
||||
|
||||
err := parseUints()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedStep != nil {
|
||||
assert.Equal(t, *tt.expectedStep, cfg.Policy.Product.MachineStepping.Value)
|
||||
}
|
||||
if tt.expectedPlatform != nil {
|
||||
assert.Equal(t, *tt.expectedPlatform, cfg.Policy.PlatformInfo.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{"0x10", 16},
|
||||
{"0o10", 8},
|
||||
{"0b10", 2},
|
||||
{"10", 10},
|
||||
{"", 10},
|
||||
{"abc", 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := getBase(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
validConfig := map[string]interface{}{
|
||||
"rootOfTrust": map[string]interface{}{
|
||||
"product": "test_product",
|
||||
"cabundlePaths": []string{"test_path"},
|
||||
"cabundles": []string{"test_bundle"},
|
||||
"checkCrl": true,
|
||||
"disallowNetwork": true,
|
||||
},
|
||||
"policy": map[string]interface{}{
|
||||
"minimumGuestSvn": 1,
|
||||
"policy": "1",
|
||||
"minimumBuild": 1,
|
||||
"minimumVersion": "0.90",
|
||||
"requireAuthorKey": true,
|
||||
"requireIdBlock": true,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty config string",
|
||||
setupConfig: func() string {
|
||||
return ""
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config file",
|
||||
setupConfig: func() string {
|
||||
configFile := filepath.Join(tempDir, "valid_config.json")
|
||||
configBytes, err := json.Marshal(validConfig)
|
||||
assert.NoError(t, err)
|
||||
if err := os.WriteFile(configFile, configBytes, 0o644); err != nil {
|
||||
t.Errorf("failed to write config file: %v", err)
|
||||
}
|
||||
return configFile
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent config file",
|
||||
setupConfig: func() string {
|
||||
return filepath.Join(tempDir, "nonexistent.json")
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON config",
|
||||
setupConfig: func() string {
|
||||
configFile := filepath.Join(tempDir, "invalid_config.json")
|
||||
if err := os.WriteFile(configFile, []byte("invalid json"), 0o644); err != nil {
|
||||
t.Errorf("failed to write invalid config file: %v", err)
|
||||
}
|
||||
return configFile
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
cfgString = tt.setupConfig()
|
||||
|
||||
err := parseConfig()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Policy)
|
||||
assert.NotNil(t, cfg.RootOfTrust)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
trustedAuthorHashes []string
|
||||
trustedIdKeyHashes []string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid hashes",
|
||||
trustedAuthorHashes: []string{"deadbeef", "cafebabe"},
|
||||
trustedIdKeyHashes: []string{"12345678", "87654321"},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty hashes",
|
||||
trustedAuthorHashes: []string{},
|
||||
trustedIdKeyHashes: []string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid author hash",
|
||||
trustedAuthorHashes: []string{"invalid_hex"},
|
||||
trustedIdKeyHashes: []string{},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid id key hash",
|
||||
trustedAuthorHashes: []string{},
|
||||
trustedIdKeyHashes: []string{"invalid_hex"},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
trustedAuthorHashes = tt.trustedAuthorHashes
|
||||
trustedIdKeyHashes = tt.trustedIdKeyHashes
|
||||
|
||||
err := parseHashes()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, len(tt.trustedAuthorHashes))
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, len(tt.trustedIdKeyHashes))
|
||||
|
||||
for i, hash := range tt.trustedAuthorHashes {
|
||||
expected, _ := hex.DecodeString(hash)
|
||||
assert.Equal(t, expected, cfg.Policy.TrustedAuthorKeyHashes[i])
|
||||
}
|
||||
|
||||
for i, hash := range tt.trustedIdKeyHashes {
|
||||
expected, _ := hex.DecodeString(hash)
|
||||
assert.Equal(t, expected, cfg.Policy.TrustedIdKeyHashes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAttestationFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
binaryFile := filepath.Join(tempDir, "attestation.bin")
|
||||
jsonFile := filepath.Join(tempDir, "attestation.json")
|
||||
|
||||
binaryData := make([]byte, 1024)
|
||||
for i := range binaryData {
|
||||
binaryData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
jsonData := &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
FamilyId: make([]byte, 16),
|
||||
ImageId: make([]byte, 16),
|
||||
ReportData: make([]byte, 64),
|
||||
Measurement: make([]byte, 48),
|
||||
HostData: make([]byte, 32),
|
||||
IdKeyDigest: make([]byte, 48),
|
||||
AuthorKeyDigest: make([]byte, 48),
|
||||
ReportId: make([]byte, 32),
|
||||
ReportIdMa: make([]byte, 32),
|
||||
ChipId: make([]byte, 64),
|
||||
Signature: make([]byte, 512),
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
|
||||
require.NoError(t, os.WriteFile(jsonFile, jsonBytes, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationFile string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid binary file",
|
||||
attestationFile: binaryFile,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid JSON file",
|
||||
attestationFile: jsonFile,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
attestationFile: filepath.Join(tempDir, "nonexistent.bin"),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attestationFile = tt.attestationFile
|
||||
|
||||
err := parseAttestationFile()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, attestationRaw)
|
||||
assert.NotEmpty(t, attestationRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevsnpverify(t *testing.T) {
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
stepping = ""
|
||||
platformInfo = ""
|
||||
tempDir := t.TempDir()
|
||||
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "attestation.bin")
|
||||
attestationData := make([]byte, abi.ReportSize+100)
|
||||
for i := range attestationData {
|
||||
attestationData[i] = byte(i % 256)
|
||||
}
|
||||
require.NoError(t, os.WriteFile(attestationFile, attestationData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
expectedMsg: "Attestation validation and verification is successful!",
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
expectedMsg: "attestation validation and verification failed",
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
args: []string{filepath.Join(tempDir, "nonexistent.bin")},
|
||||
setupMock: func(m *mocks.Verifier) {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfgString = ""
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetOut(&output)
|
||||
|
||||
err := sevsnpverify(cmd, mockVerifier, tt.args)
|
||||
fmt.Println("error1", err)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedMsg != "" {
|
||||
assert.Contains(t, output.String(), tt.expectedMsg)
|
||||
}
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnvTPMAttestation(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
binaryFile := filepath.Join(tempDir, "attestation.pb")
|
||||
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
|
||||
|
||||
textData, err := prototext.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
textFile := filepath.Join(tempDir, "attestation.txtpb")
|
||||
require.NoError(t, os.WriteFile(textFile, textData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
format string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "binary protobuf format",
|
||||
args: []string{binaryFile},
|
||||
format: FormatBinaryPB,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "text protobuf format",
|
||||
args: []string{textFile},
|
||||
format: FormatTextProto,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
args: []string{binaryFile},
|
||||
format: "invalid",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
args: []string{filepath.Join(tempDir, "nonexistent.pb")},
|
||||
format: FormatBinaryPB,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format = tt.format
|
||||
|
||||
result, err := returnvTPMAttestation(tt.args)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVtpmSevSnpverify(t *testing.T) {
|
||||
stepping = ""
|
||||
platformInfo = ""
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
|
||||
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
cfgString = ""
|
||||
format = FormatBinaryPB
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
err := vtpmSevSnpverify(tt.args, mockVerifier)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVtpmverify(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
|
||||
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format = FormatBinaryPB
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
err := vtpmverify(tt.args, mockVerifier)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func uint64Ptr(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
+694
-23
@@ -4,10 +4,12 @@ package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
@@ -311,26 +314,12 @@ func TestNewValidateAttestationValidationCmd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type MockMeasurement struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Run(igvmBinaryPath string) ([]byte, error) {
|
||||
args := m.Called(igvmBinaryPath)
|
||||
return nil, args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Stop() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(nil)
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, nil)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
buf := new(bytes.Buffer)
|
||||
@@ -346,11 +335,11 @@ func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
||||
|
||||
func TestNewMeasureCmd_RunError(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
expectedError := errors.New("mocked measurement error")
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(expectedError)
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, expectedError)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
|
||||
@@ -366,7 +355,7 @@ func TestNewMeasureCmd_RunError(t *testing.T) {
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
func TestParseConfig1(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation_policy.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
@@ -393,7 +382,7 @@ func TestParseConfig(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
func TestParseHashes1(t *testing.T) {
|
||||
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
|
||||
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
|
||||
|
||||
@@ -444,7 +433,7 @@ func TestParseFiles(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
func TestParseUints1(t *testing.T) {
|
||||
stepping = "10"
|
||||
platformInfo = "0xFF"
|
||||
|
||||
@@ -469,7 +458,7 @@ func TestParseUints(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
func TestValidateInput1(t *testing.T) {
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
@@ -494,7 +483,7 @@ func TestValidateInput(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
func TestGetBase1(t *testing.T) {
|
||||
assert.Equal(t, 16, getBase("0xFF"))
|
||||
assert.Equal(t, 8, getBase("0o77"))
|
||||
assert.Equal(t, 2, getBase("0b1010"))
|
||||
@@ -716,3 +705,685 @@ func TestDecodeJWTToJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestEnvironment() func() {
|
||||
originalMode := mode
|
||||
originalCfgString := cfgString
|
||||
originalTimeout := timeout
|
||||
originalMaxRetryDelay := maxRetryDelay
|
||||
originalPlatformInfo := platformInfo
|
||||
originalStepping := stepping
|
||||
originalTrustedAuthorKeys := trustedAuthorKeys
|
||||
originalTrustedAuthorHashes := trustedAuthorHashes
|
||||
originalTrustedIdKeys := trustedIdKeys
|
||||
originalTrustedIdKeyHashes := trustedIdKeyHashes
|
||||
originalAttestationFile := attestationFile
|
||||
originalAttestationRaw := attestationRaw
|
||||
originalOutput := output
|
||||
originalNonce := nonce
|
||||
originalFormat := format
|
||||
originalTeeNonce := teeNonce
|
||||
originalTokenNonce := tokenNonce
|
||||
originalGetTextProtoAttestationReport := getTextProtoAttestationReport
|
||||
originalGetAzureTokenJWT := getAzureTokenJWT
|
||||
originalCloud := cloud
|
||||
originalReportData := reportData
|
||||
originalCheckCrl := checkCrl
|
||||
|
||||
mode = ""
|
||||
cfgString = ""
|
||||
timeout = 0
|
||||
maxRetryDelay = 0
|
||||
platformInfo = ""
|
||||
stepping = ""
|
||||
trustedAuthorKeys = []string{}
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeys = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
attestationFile = ""
|
||||
attestationRaw = []byte{}
|
||||
output = ""
|
||||
nonce = []byte{}
|
||||
format = ""
|
||||
teeNonce = []byte{}
|
||||
tokenNonce = []byte{}
|
||||
getTextProtoAttestationReport = false
|
||||
getAzureTokenJWT = false
|
||||
cloud = ""
|
||||
reportData = []byte{}
|
||||
checkCrl = false
|
||||
|
||||
return func() {
|
||||
mode = originalMode
|
||||
cfgString = originalCfgString
|
||||
timeout = originalTimeout
|
||||
maxRetryDelay = originalMaxRetryDelay
|
||||
platformInfo = originalPlatformInfo
|
||||
stepping = originalStepping
|
||||
trustedAuthorKeys = originalTrustedAuthorKeys
|
||||
trustedAuthorHashes = originalTrustedAuthorHashes
|
||||
trustedIdKeys = originalTrustedIdKeys
|
||||
trustedIdKeyHashes = originalTrustedIdKeyHashes
|
||||
attestationFile = originalAttestationFile
|
||||
attestationRaw = originalAttestationRaw
|
||||
output = originalOutput
|
||||
nonce = originalNonce
|
||||
format = originalFormat
|
||||
teeNonce = originalTeeNonce
|
||||
tokenNonce = originalTokenNonce
|
||||
getTextProtoAttestationReport = originalGetTextProtoAttestationReport
|
||||
getAzureTokenJWT = originalGetAzureTokenJWT
|
||||
cloud = originalCloud
|
||||
reportData = originalReportData
|
||||
checkCrl = originalCheckCrl
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, content []byte) string {
|
||||
tmpfile, err := os.CreateTemp("", "test_*.bin")
|
||||
require.NoError(t, err)
|
||||
defer tmpfile.Close()
|
||||
|
||||
_, err = tmpfile.Write(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tmpfile.Name()
|
||||
}
|
||||
|
||||
func TestNewAttestationCmdEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedOutput string
|
||||
hasSubcommands bool
|
||||
}{
|
||||
{
|
||||
name: "no arguments shows help",
|
||||
args: []string{},
|
||||
expectedOutput: "Get and validate attestations",
|
||||
hasSubcommands: true,
|
||||
},
|
||||
{
|
||||
name: "help flag shows usage",
|
||||
args: []string{"--help"},
|
||||
expectedOutput: "Get and validate attestations",
|
||||
hasSubcommands: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
cmd := cli.NewAttestationCmd()
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, tt.expectedOutput)
|
||||
|
||||
if tt.hasSubcommands {
|
||||
assert.Contains(t, output, "Get and validate attestations")
|
||||
assert.Contains(t, output, "Usage:")
|
||||
assert.Contains(t, output, "Flags:")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.SDK)
|
||||
expectedErr string
|
||||
expectedOut string
|
||||
}{
|
||||
{
|
||||
name: "no arguments provided",
|
||||
args: []string{},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "accepts 1 arg(s), received 0",
|
||||
},
|
||||
{
|
||||
name: "too many arguments",
|
||||
args: []string{"snp", "extra"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "accepts 1 arg(s), received 2",
|
||||
},
|
||||
{
|
||||
name: "invalid attestation type",
|
||||
args: []string{"invalid-type"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Bad attestation type",
|
||||
},
|
||||
{
|
||||
name: "SNP with missing TEE nonce",
|
||||
args: []string{"snp"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "TEE nonce must be defined for SEV-SNP attestation",
|
||||
},
|
||||
{
|
||||
name: "vTPM with missing nonce",
|
||||
args: []string{"vtpm"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be defined for vTPM attestation",
|
||||
},
|
||||
{
|
||||
name: "Azure token with missing token nonce",
|
||||
args: []string{"azure-token"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Token nonce must be defined for Azure attestation",
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too large",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
|
||||
},
|
||||
{
|
||||
name: "vTPM nonce too large",
|
||||
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "Token nonce too large",
|
||||
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "successful TDX attestation",
|
||||
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(nil).Run(func(args mock.Arguments) {
|
||||
if _, err := args.Get(4).(*os.File).Write([]byte("mock tdx attestation")); err != nil {
|
||||
t.Fatalf("Failed to write to attestation file: %v", err)
|
||||
}
|
||||
})
|
||||
},
|
||||
expectedOut: "Fetching TDX attestation report",
|
||||
},
|
||||
{
|
||||
name: "file creation error",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Error creating attestation file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
defer func() {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
}()
|
||||
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
tc.setupMock(mockSDK)
|
||||
|
||||
if tc.name == "file creation error" {
|
||||
err := os.Mkdir(attestationFilePath, 0o755)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(attestationFilePath)
|
||||
}
|
||||
|
||||
cmd := cli.NewGetAttestationCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tc.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
output := buf.String()
|
||||
|
||||
if tc.expectedErr != "" {
|
||||
assert.Contains(t, output, tc.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tc.expectedOut != "" {
|
||||
assert.Contains(t, output, tc.expectedOut)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperations(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
t.Run("openInputFile", func(t *testing.T) {
|
||||
attestationFile = ""
|
||||
reader, err := openInputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, errEmptyFile, err)
|
||||
assert.Nil(t, reader)
|
||||
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
attestationFile = tempFile
|
||||
reader, err = openInputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, reader)
|
||||
if file, ok := reader.(*os.File); ok {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
attestationFile = "non-existent-file.bin"
|
||||
reader, err = openInputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, reader)
|
||||
})
|
||||
|
||||
t.Run("createOutputFile", func(t *testing.T) {
|
||||
output = ""
|
||||
writer, err := createOutputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, os.Stdout, writer)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
output = filepath.Join(tempDir, "test_output.txt")
|
||||
writer, err = createOutputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, writer)
|
||||
if file, ok := writer.(*os.File); ok {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
output = "/invalid/path/file.txt"
|
||||
writer, err = createOutputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, writer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidationFunctions(t *testing.T) {
|
||||
t.Run("validateFieldLength", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldName string
|
||||
field []byte
|
||||
expectedLength int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "nil field",
|
||||
fieldName: "test",
|
||||
field: nil,
|
||||
expectedLength: 32,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "correct length",
|
||||
fieldName: "test",
|
||||
field: make([]byte, 32),
|
||||
expectedLength: 32,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "incorrect length",
|
||||
fieldName: "test",
|
||||
field: make([]byte, 16),
|
||||
expectedLength: 32,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateFieldLength(tt.fieldName, tt.field, tt.expectedLength)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.fieldName)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeJWTToJSONEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte(""),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "single part",
|
||||
input: []byte("onlyonepart"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 in header",
|
||||
input: []byte("invalid@base64.validpart"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 in payload",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.invalid@base64"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON in header",
|
||||
input: []byte("bm90anNvbg.eyJzdWIiOiJ0ZXN0In0"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON in payload",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.bm90anNvbg"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "valid JWT with padding",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"),
|
||||
expected: `{
|
||||
"header": {
|
||||
"alg": "HS256"
|
||||
},
|
||||
"payload": {
|
||||
"sub": "test"
|
||||
}
|
||||
}`,
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := decodeJWTToJSON(tt.input)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
if tt.expected != "" {
|
||||
assert.JSONEq(t, tt.expected, string(result))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasureCmdEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
mockSetup func(*mmocks.MeasurementProvider)
|
||||
expectedError string
|
||||
expectedOut string
|
||||
}{
|
||||
{
|
||||
name: "no arguments",
|
||||
args: []string{},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
},
|
||||
expectedError: "requires at least 1 arg(s), only received 0",
|
||||
},
|
||||
{
|
||||
name: "single line output success",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return([]byte("ABCDEF123456"), nil)
|
||||
},
|
||||
expectedOut: "",
|
||||
},
|
||||
{
|
||||
name: "multi-line output error",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return([]byte("line1\nline2\nERROR: something went wrong"), nil)
|
||||
},
|
||||
expectedError: "ERROR: something went wrong",
|
||||
},
|
||||
{
|
||||
name: "measurement run error",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return(nil, errors.New("measurement failed"))
|
||||
},
|
||||
expectedError: "measurement failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
tt.mockSetup(mockMeasurement)
|
||||
|
||||
cli := &CLI{measurement: mockMeasurement}
|
||||
cmd := cli.NewMeasureCmd("fake_binary_path")
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
if tt.expectedError != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOut != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOut)
|
||||
}
|
||||
}
|
||||
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAttestationValidationCmdPreRunE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
flags map[string]string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "no file path provided",
|
||||
args: []string{},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "please pass the attestation report file path",
|
||||
},
|
||||
{
|
||||
name: "multiple file paths",
|
||||
args: []string{"file1.bin", "file2.bin"},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "please pass the attestation report file path",
|
||||
},
|
||||
{
|
||||
name: "unknown mode",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "unknown"},
|
||||
expectedErr: "unknown mode: unknown",
|
||||
},
|
||||
{
|
||||
name: "SNP mode missing report_data",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "SNP mode missing product",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp", "report_data": "123"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "vTPM mode missing nonce",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "vtpm"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "SNP-vTPM mode missing required flags",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp-vtpm"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "TDX mode missing report_data",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "tdx"},
|
||||
expectedErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
for key, value := range tt.flags {
|
||||
if err := cmd.Flags().Set(key, value); err != nil {
|
||||
}
|
||||
}
|
||||
|
||||
err := cmd.PreRunE(cmd, tt.args)
|
||||
if tt.expectedErr != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudProviderConfigurations(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cloud string
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "none cloud provider",
|
||||
cloud: CCNone,
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
{
|
||||
name: "azure cloud provider",
|
||||
cloud: CCAzure,
|
||||
expectedType: "azure",
|
||||
},
|
||||
{
|
||||
name: "gcp cloud provider",
|
||||
cloud: CCGCP,
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
{
|
||||
name: "default cloud provider",
|
||||
cloud: "",
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
if err := cmd.Flags().Set("cloud", tt.cloud); err != nil {
|
||||
t.Fatalf("Failed to set cloud flag: %v", err)
|
||||
}
|
||||
cloud, _ := cmd.Flags().GetString("cloud")
|
||||
assert.Equal(t, tt.cloud, cloud)
|
||||
|
||||
assert.Contains(t, cmd.Short, tt.cloud)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperationErrors(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
t.Run("file close error handling", func(t *testing.T) {
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
assert.True(t, true)
|
||||
})
|
||||
|
||||
t.Run("file write error handling", func(t *testing.T) {
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
err := os.Chmod(tempFile, 0o444)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(tempFile, []byte("new content"), 0o644)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("file read error handling", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
_, err := os.ReadFile(tempDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockSDK.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(context.Canceled)
|
||||
|
||||
cmd := cli.NewGetAttestationCmd()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
||||
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, buf.String(), "Failed to get attestation due to error")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
|
||||
func TestCLI_NewIMAMeasurementsCmd(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
connectErr error
|
||||
mockIMAData string
|
||||
mockError error
|
||||
expectedFilename string
|
||||
expectedOutput []string
|
||||
expectedError []string
|
||||
shouldCreateFile bool
|
||||
fileCreationError bool
|
||||
invalidDigestData bool
|
||||
setupCustomFile func(filename string) error
|
||||
}{
|
||||
{
|
||||
name: "successful_retrieval_default_filename",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: imaMeasurementsFilename,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "PCR10 = 0000000000000000000000000000000000000000", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "successful_retrieval_custom_filename",
|
||||
args: []string{"custom_ima_file.txt"},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: "custom_ima_file.txt",
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "custom_ima_file.txt", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "connection_error",
|
||||
args: []string{},
|
||||
connectErr: fmt.Errorf("connection failed"),
|
||||
expectedError: []string{"Failed to connect to agent: connection failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "file_creation_error",
|
||||
args: []string{"/invalid/path/file.txt"},
|
||||
connectErr: nil,
|
||||
fileCreationError: true,
|
||||
expectedError: []string{"Error creating imaMeasurements file:"},
|
||||
},
|
||||
{
|
||||
name: "sdk_error",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockError: fmt.Errorf("SDK communication failed"),
|
||||
expectedError: []string{"Error retrieving Linux IMA measurements file: SDK communication failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "verification_failure_wrong_pcr",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "10 9999999999999999999999999999999999999999 ima-ng sha1:0000000000000000000000000000000000000000 /usr/bin/test",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully"},
|
||||
expectedError: []string{"Measurements file not verified ❌"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "empty_measurements_file",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_non_pcr10_entries",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_zero_digest_replacement",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
|
||||
cli := &CLI{
|
||||
agentSDK: mockSDK,
|
||||
connectErr: tc.connectErr,
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.On("IMAMeasurements", mock.Anything, mock.Anything).Return([]byte(tc.mockIMAData), tc.mockError)
|
||||
}
|
||||
|
||||
cmd := cli.NewIMAMeasurementsCmd()
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd.SetOut(&output)
|
||||
cmd.SetErr(&output)
|
||||
|
||||
expectedFilename := tc.expectedFilename
|
||||
if expectedFilename == "" {
|
||||
if len(tc.args) > 0 {
|
||||
expectedFilename = tc.args[0]
|
||||
} else {
|
||||
expectedFilename = imaMeasurementsFilename
|
||||
}
|
||||
}
|
||||
|
||||
if tc.setupCustomFile != nil {
|
||||
err := tc.setupCustomFile(expectedFilename)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cmd.SetArgs(tc.args)
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err, "Command execution failed")
|
||||
|
||||
outputStr := output.String()
|
||||
|
||||
for _, expectedMsg := range tc.expectedOutput {
|
||||
assert.Contains(t, outputStr, expectedMsg, "Expected output message not found")
|
||||
}
|
||||
|
||||
for _, expectedErr := range tc.expectedError {
|
||||
assert.Contains(t, outputStr, expectedErr, "Expected error message not found")
|
||||
}
|
||||
|
||||
if tc.shouldCreateFile && tc.connectErr == nil && !tc.fileCreationError && tc.mockError == nil {
|
||||
if _, err := os.Stat(expectedFilename); err == nil {
|
||||
os.Remove(expectedFilename)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+10
-6
@@ -38,9 +38,11 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
Example: `create-vm`,
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
if c.managerClient == nil || c.connectErr != nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
@@ -74,7 +76,7 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, agentCVMCaUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, caUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM")
|
||||
if err := cmd.MarkFlagRequired(serverURL); err != nil {
|
||||
@@ -92,8 +94,10 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
Example: `remove-vm <cvm_id>`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
if c.managerClient == nil || c.connectErr != nil {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if c.connectErr != nil {
|
||||
|
||||
@@ -0,0 +1,600 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func TestCLI_NewCreateVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
setupFiles func(string) error
|
||||
flags map[string]string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM creation with all flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "debug" &&
|
||||
req.AgentCvmCaUrl == "https://ca.com" &&
|
||||
req.Ttl == "1h0m0s" &&
|
||||
string(req.AgentCvmServerCaCert) == "ca-cert-content" &&
|
||||
string(req.AgentCvmClientKey) == "client-key-content" &&
|
||||
string(req.AgentCvmClientCert) == "client-cert-content"
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-123",
|
||||
ForwardedPort: "8080",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"server-ca.pem": "ca-cert-content",
|
||||
"client-key.pem": "client-key-content",
|
||||
"client-crt.pem": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "server-ca.pem",
|
||||
"client-key": "client-key.pem",
|
||||
"client-crt": "client-crt.pem",
|
||||
"ca-url": "https://ca.com",
|
||||
"log-level": "debug",
|
||||
"ttl": "1h",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-123 and port 8080",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "successful VM creation with minimal flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "" &&
|
||||
req.AgentCvmCaUrl == "" &&
|
||||
req.Ttl == "" &&
|
||||
len(req.AgentCvmServerCaCert) == 0 &&
|
||||
len(req.AgentCvmClientKey) == 0 &&
|
||||
len(req.AgentCvmClientCert) == 0
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-456",
|
||||
ForwardedPort: "9090",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // No files needed for minimal test
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-456 and port 9090",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "certificate loading failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as cert loading fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create the cert file
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "nonexistent-ca.pem",
|
||||
},
|
||||
expectedError: "Error loading certs:",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "CreateVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.Anything).Return(nil, errors.New("API error"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Error creating virtual machine: API error ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing required server-url flag",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{}, // No server-url flag
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "cli-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := os.Chdir(oldDir)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
for flag, value := range tt.flags {
|
||||
err := cmd.Flags().Set(flag, value)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLI_NewRemoveVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
args []string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM removal",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-123",
|
||||
}).Return(&emptypb.Empty{}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedOutput: "✅ Virtual machine removed successfully",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "RemoveVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-456",
|
||||
}).Return(nil, errors.New("removal failed"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-456"},
|
||||
expectedError: "Error removing virtual machine: removal failed ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing VM ID argument",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{}, // No VM ID provided
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "too many arguments",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123", "extra-arg"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewRemoveVMCmd()
|
||||
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) (string, error)
|
||||
path string
|
||||
expectedResult []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful file read",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
filePath := filepath.Join(tmpDir, "test.txt")
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0o644)
|
||||
return filePath, err
|
||||
},
|
||||
expectedResult: []byte("test content"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty path returns nil",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
path: "",
|
||||
expectedResult: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file returns error",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return filepath.Join(tmpDir, "nonexistent.txt"), nil
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "fileReader-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
filePath, err := tt.setupFile(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.path != "" {
|
||||
filePath = tt.path
|
||||
}
|
||||
|
||||
result, err := fileReader(filePath)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCerts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFiles func(string) error
|
||||
setupGlobal func(string)
|
||||
expectError bool
|
||||
validate func(*testing.T, *manager.CreateReq)
|
||||
}{
|
||||
{
|
||||
name: "successful cert loading with all files",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
"server.ca": "server-ca-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "server.ca")
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Equal(t, []byte("client-key-content"), req.AgentCvmClientKey)
|
||||
assert.Equal(t, []byte("client-cert-content"), req.AgentCvmClientCert)
|
||||
assert.Equal(t, []byte("server-ca-content"), req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful cert loading with empty paths",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = ""
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Nil(t, req.AgentCvmClientKey)
|
||||
assert.Nil(t, req.AgentCvmClientCert)
|
||||
assert.Nil(t, req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client key file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create client key file
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "client cert file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
// Create client key but not cert
|
||||
return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644)
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "server CA file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "loadCerts-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store original global variables
|
||||
origClientKey := agentCVMClientKey
|
||||
origClientCrt := agentCVMClientCrt
|
||||
origServerCA := agentCVMServerCA
|
||||
|
||||
// Setup global variables for test
|
||||
tt.setupGlobal(tmpDir)
|
||||
|
||||
// Restore original values after test
|
||||
defer func() {
|
||||
agentCVMClientKey = origClientKey
|
||||
agentCVMClientCrt = origClientCrt
|
||||
agentCVMServerCA = origServerCA
|
||||
}()
|
||||
|
||||
result, err := loadCerts()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandCreation(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
|
||||
t.Run("create-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewCreateVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "create-vm", cmd.Use)
|
||||
assert.Equal(t, "Create a new virtual machine", cmd.Short)
|
||||
|
||||
// Check that required flags are set
|
||||
flag := cmd.Flags().Lookup("server-url")
|
||||
assert.NotNil(t, flag)
|
||||
// Note: We can't easily test MarkFlagRequired in unit tests
|
||||
})
|
||||
|
||||
t.Run("remove-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewRemoveVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "remove-vm", cmd.Use)
|
||||
assert.Equal(t, "Remove a virtual machine", cmd.Short)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTTLHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ttlInput string
|
||||
expectedTTL time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid duration",
|
||||
ttlInput: "1h30m",
|
||||
expectedTTL: time.Hour + 30*time.Minute,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero duration",
|
||||
ttlInput: "0",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
ttlInput: "",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockCLI := &CLI{
|
||||
managerClient: new(mocks.ManagerServiceClient),
|
||||
}
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
if tt.ttlInput != "" {
|
||||
err := cmd.Flags().Set("ttl", tt.ttlInput)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedTTL, ttl)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+3
-1
@@ -62,5 +62,7 @@ func (c *CLI) InitializeManagerClient(cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
func (c *CLI) Close() {
|
||||
c.client.Close()
|
||||
if c.client != nil {
|
||||
c.client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user