mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
8eb1fac9ad
* Refactor and update dependencies in the project - Updated go.sum to replace `github.com/absmach/magistrala` with `github.com/absmach/supermq` across various modules. - Removed VSock configuration from environment variables and QEMU arguments. - Updated QEMU configuration and related tests to remove references to guest CID and VSock. - Added new HTTP transport layer for API endpoints in the manager. - Introduced Prometheus monitoring configuration with alert rules and Alertmanager setup. - Updated service and VM interfaces to remove unused methods and references. - Refactored tests to align with the new structure and dependencies. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MaxVMs configuration and enforce limit on VM creation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for HTTP transport handlers and endpoints Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add test case for exceeding maximum number of VMs in TestRun Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in TestHandlerWithCustomRouter to ensure response writing is checked Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update dependencies to latest versions - Upgrade cel.dev/expr from v0.23.0 to v0.24.0 - Upgrade github.com/absmach/supermq from v0.16.0 to v0.17.0 - Upgrade github.com/cenkalti/backoff from v4.3.0 to v5.0.2 - Upgrade github.com/cncf/xds/go to v0.0.0-20250501225837-2ac532fd4443 - Upgrade github.com/go-chi/chi/v5 from v5.2.1 to v5.2.2 - Upgrade github.com/go-jose/go-jose/v3 from v3.0.3 to v3.0.4 - Upgrade github.com/gofrs/uuid/v5 from v5.3.0 to v5.3.2 - Upgrade github.com/prometheus/client_golang from v1.22.0 to v1.23.0 - Upgrade github.com/prometheus/client_model from v0.6.1 to v0.6.2 - Upgrade github.com/prometheus/common from v0.62.0 to v0.65.0 - Upgrade github.com/prometheus/procfs from v0.15.1 to v0.16.1 - Upgrade go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp from v0.60.0 to v0.62.0 - Upgrade 
go.opentelemetry.io/otel/exporters/otlp/otlptrace from v1.36.0 to v1.37.0 - Upgrade golang.org/x/crypto from v0.39.0 to v0.40.0 - Upgrade golang.org/x/sys from v0.33.0 to v0.34.0 - Upgrade golang.org/x/text from v0.26.0 to v0.27.0 - Upgrade golang.org/x/time from v0.11.0 to v0.12.0 - Upgrade google.golang.org/grpc from v1.73.0 to v1.74.2 Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
1390 lines
37 KiB
Go
1390 lines
37 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/absmach/supermq/pkg/errors"
|
|
"github.com/google/go-sev-guest/abi"
|
|
"github.com/google/go-sev-guest/proto/check"
|
|
"github.com/google/go-sev-guest/proto/sevsnp"
|
|
"github.com/spf13/cobra"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
|
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
|
)
|
|
|
|
func TestNewAttestationCmd(t *testing.T) {
|
|
mockSDK := new(mocks.SDK)
|
|
cli := &CLI{agentSDK: mockSDK}
|
|
cmd := cli.NewAttestationCmd()
|
|
|
|
assert.Equal(t, "attestation [command]", cmd.Use)
|
|
assert.Equal(t, "Get and validate attestations", cmd.Short)
|
|
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
|
|
reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce)
|
|
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil)
|
|
|
|
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
|
err := cmd.Execute()
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, buf.String(), "Get and validate attestations")
|
|
}
|
|
|
|
func TestNewGetAttestationCmd(t *testing.T) {
|
|
validattestation, err := os.ReadFile("../attestation.bin")
|
|
require.NoError(t, err)
|
|
|
|
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
|
vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
|
tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
|
|
|
testCases := []struct {
|
|
name string
|
|
args []string
|
|
mockResponse []byte
|
|
mockError error
|
|
expectedErr string
|
|
expectedOut string
|
|
}{
|
|
{
|
|
name: "successful SNP attestation retrieval",
|
|
args: []string{"snp", "--tee", teeNonce},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedOut: "Attestation result retrieved and saved successfully!",
|
|
},
|
|
{
|
|
name: "successful vTPM attestation retrieval",
|
|
args: []string{"vtpm", "--vtpm", vtpmNonce},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedOut: "Attestation result retrieved and saved successfully!",
|
|
},
|
|
{
|
|
name: "successful SNP-vTPM attestation retrieval",
|
|
args: []string{"snp-vtpm", "--tee", teeNonce, "--vtpm", vtpmNonce},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedOut: "Attestation result retrieved and saved successfully!",
|
|
},
|
|
{
|
|
name: "missing vTPM nonce",
|
|
args: []string{"snp-vtpm", "--tee", teeNonce},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedOut: "vTPM nonce must be defined for vTPM attestation",
|
|
},
|
|
{
|
|
name: "missing TEE nonce",
|
|
args: []string{"snp-vtpm", "--vtpm", vtpmNonce},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedOut: "TEE nonce must be defined for SEV-SNP attestation",
|
|
},
|
|
{
|
|
name: "invalid report data size",
|
|
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 65))},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
|
|
},
|
|
{
|
|
name: "invalid vTPM data size",
|
|
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
|
},
|
|
{
|
|
name: "invalid arguments",
|
|
args: []string{"invalid"},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "Bad attestation type: invalid argument ",
|
|
},
|
|
{
|
|
name: "failed to get attestation",
|
|
args: []string{"snp", "--tee", teeNonce},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "Failed to get attestation due to error",
|
|
},
|
|
{
|
|
name: "Textproto report error",
|
|
args: []string{"snp", "--tee", teeNonce, "--reporttextproto"},
|
|
mockResponse: []byte("mock attestation"),
|
|
mockError: nil,
|
|
expectedErr: "Fetching SEV-SNP attestation report\nError converting SNP attestation to JSON: attestation contents too small : attestation contents too small (0x10 bytes). Want at least 0x4a0 bytes ❌\n",
|
|
},
|
|
{
|
|
name: "successful Textproto report",
|
|
args: []string{"snp", "--tee", teeNonce, "--reporttextproto"},
|
|
mockResponse: validattestation,
|
|
mockError: nil,
|
|
expectedOut: "Attestation result retrieved and saved successfully!",
|
|
},
|
|
{
|
|
name: "connection error",
|
|
args: []string{"snp", "--tee", teeNonce},
|
|
mockResponse: nil,
|
|
mockError: errors.New("failed to connect to agent"),
|
|
expectedErr: "Failed to connect to agent",
|
|
},
|
|
{
|
|
name: "successful Azure token retrieval",
|
|
args: []string{"azure-token", "--token", tokenNonce},
|
|
mockResponse: []byte("eyJhbGciOiAiUlMyNTYifQ.eyJzdWIiOiAidGVzdC11c2VyIn0.signature"),
|
|
mockError: nil,
|
|
expectedOut: "Fetching Azure token\nAttestation result retrieved and saved successfully!\n",
|
|
},
|
|
{
|
|
name: "failed to retrieve Azure token",
|
|
args: []string{"azure-token", "--token", tokenNonce},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "Fetching Azure token\nFailed to get attestation result due to error: error ❌\n",
|
|
},
|
|
{
|
|
name: "invalid token nonce size",
|
|
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))},
|
|
mockResponse: nil,
|
|
mockError: errors.New("error"),
|
|
expectedErr: "Fetching Azure token\nvTPM nonce must be a hex encoded string of length lesser or equal 32 bytes ❌ \n",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Cleanup(func() {
|
|
os.Remove(attestationFilePath)
|
|
os.Remove(attestationReportJson)
|
|
})
|
|
mockSDK := new(mocks.SDK)
|
|
cli := &CLI{agentSDK: mockSDK}
|
|
if tc.name == "connection error" {
|
|
cli.connectErr = errors.New("failed to connect to agent")
|
|
}
|
|
cmd := cli.NewGetAttestationCmd()
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
|
|
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
|
_, err := args.Get(4).(*os.File).Write(tc.mockResponse)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
mockSDK.On("AttestationResult", mock.Anything, [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
|
_, err := args.Get(3).(*os.File).Write(tc.mockResponse)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
cmd.SetArgs(tc.args)
|
|
err := cmd.Execute()
|
|
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, buf.String(), tc.expectedErr)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, buf.String(), tc.expectedOut)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewValidateAttestationValidationCmdDefaults(t *testing.T) {
|
|
cli := &CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
assert.Equal(t, "validate", cmd.Use)
|
|
expectedMessage := fmt.Sprintf("Validate and verify attestation information. You can define the confidential computing cloud provider (%s, %s, %s; %s is the default) and can choose from 4 modes: %s, %s, %s, and %s. Default mode is %s.", CCNone, CCAzure, CCGCP, CCNone, SNP, VTPM, SNPvTPM, TDX, SNP)
|
|
assert.Equal(t, expectedMessage, cmd.Short)
|
|
|
|
assert.Equal(t, fmt.Sprint(defaultMinimumTcb), cmd.Flag("minimum_tcb").Value.String())
|
|
assert.Equal(t, fmt.Sprint(defaultMinimumLaunchTcb), cmd.Flag("minimum_lauch_tcb").Value.String())
|
|
assert.Equal(t, fmt.Sprint(defaultGuestPolicy), cmd.Flag("guest_policy").Value.String())
|
|
assert.Equal(t, fmt.Sprint(defaultMinimumGuestSvn), cmd.Flag("minimum_guest_svn").Value.String())
|
|
assert.Equal(t, fmt.Sprint(defaultMinimumBuild), cmd.Flag("minimum_build").Value.String())
|
|
assert.Equal(t, defaultCheckCrl, cmd.Flag("check_crl").Value.String() == "true")
|
|
assert.Equal(t, fmt.Sprint(defaultTimeout), cmd.Flag("timeout").Value.String())
|
|
assert.Equal(t, fmt.Sprint(defaultMaxRetryDelay), cmd.Flag("max_retry_delay").Value.String())
|
|
}
|
|
|
|
func TestNewValidateAttestationValidationCmd(t *testing.T) {
|
|
cli := &CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
t.Run("missing attestation report file path", func(t *testing.T) {
|
|
err := cmd.Execute()
|
|
assert.Error(t, err)
|
|
assert.Equal(t, "please pass the attestation report file path", err.Error())
|
|
})
|
|
|
|
t.Run("unknown mode", func(t *testing.T) {
|
|
cmd.SetArgs([]string{attestationFilePath, "--mode=invalid"})
|
|
err := cmd.Execute()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "unknown mode")
|
|
})
|
|
|
|
t.Run("snp mode with missing flags", func(t *testing.T) {
|
|
cmd.SetArgs([]string{attestationFilePath, "--mode=snp"})
|
|
err := cmd.Execute()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "required flag(s) \"product\", \"report_data\" not set")
|
|
})
|
|
|
|
t.Run("vtpm mode with missing flags", func(t *testing.T) {
|
|
cmd.SetArgs([]string{vtpmFilePath, "--mode=vtpm"})
|
|
err := cmd.Execute()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "required flag(s) \"format\", \"nonce\", \"output\", \"product\", \"report_data\" not set")
|
|
})
|
|
|
|
t.Run("snp-vtpm mode with missing flags", func(t *testing.T) {
|
|
cmd.SetArgs([]string{vtpmFilePath, "--mode=snp-vtpm"})
|
|
err := cmd.Execute()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "required flag(s) \"format\", \"nonce\", \"output\", \"product\", \"report_data\" not set")
|
|
})
|
|
|
|
t.Run("valid snp mode execution", func(t *testing.T) {
|
|
cli := CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
|
t.Log("Mock RunE executed instead of sevsnpverify")
|
|
return nil
|
|
}
|
|
|
|
cmd.SetArgs([]string{
|
|
"../attestation.bin",
|
|
"--mode=snp",
|
|
"--report_data=" +
|
|
"11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff" +
|
|
"11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff",
|
|
"--product=Milan",
|
|
})
|
|
err := cmd.PreRunE(cmd, []string{"../attestation.bin"})
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("valid vtpm mode execution", func(t *testing.T) {
|
|
cli := CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
|
t.Log("Mock RunE executed instead of vtpmverify")
|
|
return nil
|
|
}
|
|
|
|
cmd.SetArgs([]string{vtpmFilePath, "--mode=vtpm", "--nonce=123abc", "--format=binarypb", "--output=some_output"})
|
|
|
|
err := cmd.PreRunE(cmd, []string{"../quote.dat"})
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("valid snp-vtpm mode execution", func(t *testing.T) {
|
|
cli := CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
|
t.Log("Mock RunE executed instead of vtpmSevSnpverify")
|
|
return nil
|
|
}
|
|
|
|
cmd.SetArgs([]string{vtpmFilePath, "--mode=snp-vtpm", "--nonce=123abc", "--format=textproto", "--output=some_output"})
|
|
err := cmd.PreRunE(cmd, []string{"../quote.dat"})
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
|
cliInstance := &CLI{}
|
|
mockMeasurement := new(mmocks.MeasurementProvider)
|
|
cliInstance.measurement = mockMeasurement
|
|
|
|
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, nil)
|
|
|
|
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
|
buf := new(bytes.Buffer)
|
|
cmd.SetOut(buf)
|
|
cmd.SetErr(buf)
|
|
cmd.SetArgs([]string{"testfile.igvm"})
|
|
|
|
err := cmd.Execute()
|
|
|
|
assert.NoError(t, err)
|
|
mockMeasurement.AssertExpectations(t)
|
|
}
|
|
|
|
func TestNewMeasureCmd_RunError(t *testing.T) {
|
|
cliInstance := &CLI{}
|
|
mockMeasurement := new(mmocks.MeasurementProvider)
|
|
cliInstance.measurement = mockMeasurement
|
|
expectedError := errors.New("mocked measurement error")
|
|
|
|
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, expectedError)
|
|
|
|
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
|
|
|
buf := new(bytes.Buffer)
|
|
cmd.SetOut(buf)
|
|
cmd.SetErr(buf)
|
|
cmd.SetArgs([]string{"testfile.igvm"})
|
|
|
|
err := cmd.Execute()
|
|
|
|
assert.Error(t, err)
|
|
assert.Equal(t, expectedError.Error(), err.Error())
|
|
mockMeasurement.AssertExpectations(t)
|
|
}
|
|
|
|
func TestParseConfig1(t *testing.T) {
|
|
tmpfile, err := os.CreateTemp("", "attestation_policy.json")
|
|
require.NoError(t, err)
|
|
defer os.Remove(tmpfile.Name())
|
|
|
|
cfgString = ""
|
|
|
|
err = parseConfig()
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, cfg.RootOfTrust)
|
|
assert.NotNil(t, cfg.Policy)
|
|
|
|
cfgString = tmpfile.Name()
|
|
|
|
_, err = tmpfile.WriteString(`{"rootOfTrust":{"product":"test_product"},"policy":{"minimumGuestSvn":1}}`)
|
|
require.NoError(t, err)
|
|
err = parseConfig()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "test_product", cfg.RootOfTrust.Product)
|
|
assert.Equal(t, uint32(1), cfg.Policy.MinimumGuestSvn)
|
|
|
|
_, err = tmpfile.WriteString(`{"invalid_json"`)
|
|
require.NoError(t, err)
|
|
err = parseConfig()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestParseHashes1(t *testing.T) {
|
|
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
|
|
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
|
|
|
|
cfg = check.Config{}
|
|
if cfg.Policy == nil {
|
|
cfg.Policy = &check.Policy{}
|
|
}
|
|
|
|
err := parseHashes()
|
|
assert.NoError(t, err)
|
|
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, 1)
|
|
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, 1)
|
|
|
|
trustedAuthorHashes = []string{"invalid_hash"}
|
|
err = parseHashes()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestParseFiles(t *testing.T) {
|
|
attestationFile = "test_attestation.bin"
|
|
authorKeyFile := "test_author_key.pem"
|
|
idKeyFile := "test_id_key.pem"
|
|
|
|
err := os.WriteFile(attestationFile, []byte("test attestation"), 0o644)
|
|
assert.NoError(t, err)
|
|
err = os.WriteFile(authorKeyFile, []byte("test author key"), 0o644)
|
|
assert.NoError(t, err)
|
|
err = os.WriteFile(idKeyFile, []byte("test id key"), 0o644)
|
|
assert.NoError(t, err)
|
|
|
|
trustedAuthorKeys = []string{authorKeyFile}
|
|
trustedIdKeys = []string{idKeyFile}
|
|
|
|
err = parseAttestationFile()
|
|
assert.NoError(t, err)
|
|
err = parseTrustedKeys()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, []byte("test attestation"), attestationRaw)
|
|
assert.Len(t, cfg.Policy.TrustedAuthorKeys, 1)
|
|
assert.Len(t, cfg.Policy.TrustedIdKeys, 1)
|
|
|
|
os.Remove(attestationFile)
|
|
os.Remove(authorKeyFile)
|
|
os.Remove(idKeyFile)
|
|
|
|
attestationFile = "non_existent_file.bin"
|
|
err = parseAttestationFile()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestParseUints1(t *testing.T) {
|
|
stepping = "10"
|
|
platformInfo = "0xFF"
|
|
|
|
cfg = check.Config{}
|
|
if cfg.Policy == nil {
|
|
cfg.Policy = &check.Policy{
|
|
Product: &sevsnp.SevProduct{},
|
|
}
|
|
}
|
|
err := parseUints()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, uint32(10), cfg.Policy.Product.MachineStepping.Value)
|
|
assert.Equal(t, uint64(255), cfg.Policy.PlatformInfo.Value)
|
|
|
|
stepping = "invalid"
|
|
err = parseUints()
|
|
assert.Error(t, err)
|
|
|
|
stepping = "10"
|
|
platformInfo = "invalid"
|
|
err = parseUints()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestValidateInput1(t *testing.T) {
|
|
cfg = check.Config{}
|
|
if cfg.Policy == nil {
|
|
cfg.Policy = &check.Policy{}
|
|
}
|
|
if cfg.RootOfTrust == nil {
|
|
cfg.RootOfTrust = &check.RootOfTrust{}
|
|
}
|
|
cfg.Policy.ReportData = make([]byte, 64)
|
|
cfg.Policy.HostData = make([]byte, 32)
|
|
cfg.Policy.FamilyId = make([]byte, 16)
|
|
cfg.Policy.ImageId = make([]byte, 16)
|
|
cfg.Policy.ReportId = make([]byte, 32)
|
|
cfg.Policy.ReportIdMa = make([]byte, 32)
|
|
cfg.Policy.Measurement = make([]byte, 48)
|
|
cfg.Policy.ChipId = make([]byte, 64)
|
|
|
|
err := validateInput()
|
|
assert.NoError(t, err)
|
|
|
|
cfg.Policy.ReportData = make([]byte, 32)
|
|
err = validateInput()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestGetBase1(t *testing.T) {
|
|
assert.Equal(t, 16, getBase("0xFF"))
|
|
assert.Equal(t, 8, getBase("0o77"))
|
|
assert.Equal(t, 2, getBase("0b1010"))
|
|
assert.Equal(t, 10, getBase("123"))
|
|
}
|
|
|
|
func TestAttestationToJSON(t *testing.T) {
|
|
validReport, err := os.ReadFile("../attestation.bin")
|
|
require.NoError(t, err)
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
err error
|
|
}{
|
|
{
|
|
name: "Valid report",
|
|
input: validReport,
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Invalid report size",
|
|
input: make([]byte, abi.ReportSize-1),
|
|
err: errReportSize,
|
|
},
|
|
{
|
|
name: "Nil input",
|
|
input: nil,
|
|
err: errReportSize,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := attesationToJSON(tt.input)
|
|
assert.True(t, errors.Contains(err, tt.err))
|
|
if tt.err != nil {
|
|
assert.Nil(t, got)
|
|
return
|
|
}
|
|
|
|
require.NotNil(t, got)
|
|
|
|
var js map[string]interface{}
|
|
err = json.Unmarshal(got, &js)
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAttestationFromJSON(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
err error
|
|
validate func(t *testing.T, output []byte)
|
|
}{
|
|
{
|
|
name: "Valid JSON",
|
|
input: func() []byte {
|
|
att := &sevsnp.Attestation{
|
|
Report: &sevsnp.Report{
|
|
CurrentTcb: 1,
|
|
FamilyId: make([]byte, 16),
|
|
ImageId: make([]byte, 16),
|
|
ReportData: make([]byte, 64),
|
|
Measurement: make([]byte, 48),
|
|
HostData: make([]byte, 32),
|
|
IdKeyDigest: make([]byte, 48),
|
|
AuthorKeyDigest: make([]byte, 48),
|
|
ReportId: make([]byte, 32),
|
|
ReportIdMa: make([]byte, 32),
|
|
ChipId: make([]byte, 64),
|
|
Signature: make([]byte, 512),
|
|
},
|
|
}
|
|
data, err := json.Marshal(att)
|
|
require.NoError(t, err)
|
|
return data
|
|
}(),
|
|
err: nil,
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.NotEmpty(t, output)
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid JSON",
|
|
input: []byte(`{"invalid": json`),
|
|
err: errors.New("invalid character 'j' looking for beginning of value"),
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.Nil(t, output)
|
|
},
|
|
},
|
|
{
|
|
name: "Empty input",
|
|
input: []byte{},
|
|
err: errors.New("unexpected end of JSON input"),
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.Nil(t, output)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := attesationFromJSON(tt.input)
|
|
assert.True(t, errors.Contains(err, tt.err))
|
|
tt.validate(t, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsFileJSON(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
filename string
|
|
want bool
|
|
}{
|
|
{
|
|
name: "Valid JSON extension",
|
|
filename: "test.json",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "Valid JSON extension with path",
|
|
filename: "/path/to/test.json",
|
|
want: true,
|
|
},
|
|
{
|
|
name: "Invalid extension",
|
|
filename: "test.txt",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "No extension",
|
|
filename: "test",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "JSON in filename",
|
|
filename: "json.txt",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "Empty string",
|
|
filename: "",
|
|
want: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := isFileJSON(tt.filename)
|
|
assert.Equal(t, tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRoundTrip(t *testing.T) {
|
|
originalReport, err := os.ReadFile("../attestation.bin")
|
|
require.NoError(t, err)
|
|
jsonData, err := attesationToJSON(originalReport)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, jsonData)
|
|
|
|
roundTripReport, err := attesationFromJSON(jsonData)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, roundTripReport)
|
|
}
|
|
|
|
func TestDecodeJWTToJSON(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
err error
|
|
validate func(t *testing.T, output []byte)
|
|
}{
|
|
{
|
|
name: "Valid JWT",
|
|
input: []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." +
|
|
"eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." +
|
|
"SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"),
|
|
err: nil,
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.NotEmpty(t, output)
|
|
assert.Contains(t, string(output), `"header"`)
|
|
assert.Contains(t, string(output), `"payload"`)
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid JWT - one part",
|
|
input: []byte("justonepart"),
|
|
err: fmt.Errorf("invalid JWT: must have at least 2 parts"),
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.Nil(t, output)
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid Base64",
|
|
input: []byte("bad@@@.header"),
|
|
err: errors.New("illegal base64 data"),
|
|
validate: func(t *testing.T, output []byte) {
|
|
assert.Nil(t, output)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := decodeJWTToJSON(tt.input)
|
|
|
|
if tt.err != nil {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.err.Error())
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
tt.validate(t, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupTestEnvironment() func() {
|
|
originalMode := mode
|
|
originalCfgString := cfgString
|
|
originalTimeout := timeout
|
|
originalMaxRetryDelay := maxRetryDelay
|
|
originalPlatformInfo := platformInfo
|
|
originalStepping := stepping
|
|
originalTrustedAuthorKeys := trustedAuthorKeys
|
|
originalTrustedAuthorHashes := trustedAuthorHashes
|
|
originalTrustedIdKeys := trustedIdKeys
|
|
originalTrustedIdKeyHashes := trustedIdKeyHashes
|
|
originalAttestationFile := attestationFile
|
|
originalAttestationRaw := attestationRaw
|
|
originalOutput := output
|
|
originalNonce := nonce
|
|
originalFormat := format
|
|
originalTeeNonce := teeNonce
|
|
originalTokenNonce := tokenNonce
|
|
originalGetTextProtoAttestationReport := getTextProtoAttestationReport
|
|
originalGetAzureTokenJWT := getAzureTokenJWT
|
|
originalCloud := cloud
|
|
originalReportData := reportData
|
|
originalCheckCrl := checkCrl
|
|
|
|
mode = ""
|
|
cfgString = ""
|
|
timeout = 0
|
|
maxRetryDelay = 0
|
|
platformInfo = ""
|
|
stepping = ""
|
|
trustedAuthorKeys = []string{}
|
|
trustedAuthorHashes = []string{}
|
|
trustedIdKeys = []string{}
|
|
trustedIdKeyHashes = []string{}
|
|
attestationFile = ""
|
|
attestationRaw = []byte{}
|
|
output = ""
|
|
nonce = []byte{}
|
|
format = ""
|
|
teeNonce = []byte{}
|
|
tokenNonce = []byte{}
|
|
getTextProtoAttestationReport = false
|
|
getAzureTokenJWT = false
|
|
cloud = ""
|
|
reportData = []byte{}
|
|
checkCrl = false
|
|
|
|
return func() {
|
|
mode = originalMode
|
|
cfgString = originalCfgString
|
|
timeout = originalTimeout
|
|
maxRetryDelay = originalMaxRetryDelay
|
|
platformInfo = originalPlatformInfo
|
|
stepping = originalStepping
|
|
trustedAuthorKeys = originalTrustedAuthorKeys
|
|
trustedAuthorHashes = originalTrustedAuthorHashes
|
|
trustedIdKeys = originalTrustedIdKeys
|
|
trustedIdKeyHashes = originalTrustedIdKeyHashes
|
|
attestationFile = originalAttestationFile
|
|
attestationRaw = originalAttestationRaw
|
|
output = originalOutput
|
|
nonce = originalNonce
|
|
format = originalFormat
|
|
teeNonce = originalTeeNonce
|
|
tokenNonce = originalTokenNonce
|
|
getTextProtoAttestationReport = originalGetTextProtoAttestationReport
|
|
getAzureTokenJWT = originalGetAzureTokenJWT
|
|
cloud = originalCloud
|
|
reportData = originalReportData
|
|
checkCrl = originalCheckCrl
|
|
}
|
|
}
|
|
|
|
// createTempFile writes content to a new file in the default temporary
// directory and returns its path. The file is closed before returning; the
// caller is responsible for removing it. Any failure aborts the calling test.
func createTempFile(t *testing.T, content []byte) string {
	// Attribute failures to the caller's line rather than to this helper.
	t.Helper()

	tmpfile, err := os.CreateTemp("", "test_*.bin")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	name := tmpfile.Name()

	if _, err := tmpfile.Write(content); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = tmpfile.Close()
		t.Fatalf("writing temp file %q: %v", name, err)
	}

	// Close explicitly so a failed close is reported instead of being
	// silently discarded by a deferred call.
	if err := tmpfile.Close(); err != nil {
		t.Fatalf("closing temp file %q: %v", name, err)
	}

	return name
}
|
|
|
|
func TestNewAttestationCmdEdgeCases(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
args []string
|
|
expectedOutput string
|
|
hasSubcommands bool
|
|
}{
|
|
{
|
|
name: "no arguments shows help",
|
|
args: []string{},
|
|
expectedOutput: "Get and validate attestations",
|
|
hasSubcommands: true,
|
|
},
|
|
{
|
|
name: "help flag shows usage",
|
|
args: []string{"--help"},
|
|
expectedOutput: "Get and validate attestations",
|
|
hasSubcommands: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSDK := new(mocks.SDK)
|
|
cli := &CLI{agentSDK: mockSDK}
|
|
cmd := cli.NewAttestationCmd()
|
|
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
cmd.SetErr(&buf)
|
|
cmd.SetArgs(tt.args)
|
|
|
|
err := cmd.Execute()
|
|
assert.NoError(t, err)
|
|
|
|
output := buf.String()
|
|
assert.Contains(t, output, tt.expectedOutput)
|
|
|
|
if tt.hasSubcommands {
|
|
assert.Contains(t, output, "Get and validate attestations")
|
|
assert.Contains(t, output, "Usage:")
|
|
assert.Contains(t, output, "Flags:")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
|
defer setupTestEnvironment()()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
args []string
|
|
setupMock func(*mocks.SDK)
|
|
expectedErr string
|
|
expectedOut string
|
|
}{
|
|
{
|
|
name: "no arguments provided",
|
|
args: []string{},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "accepts 1 arg(s), received 0",
|
|
},
|
|
{
|
|
name: "too many arguments",
|
|
args: []string{"snp", "extra"},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "accepts 1 arg(s), received 2",
|
|
},
|
|
{
|
|
name: "invalid attestation type",
|
|
args: []string{"invalid-type"},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "Bad attestation type",
|
|
},
|
|
{
|
|
name: "SNP with missing TEE nonce",
|
|
args: []string{"snp"},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "TEE nonce must be defined for SEV-SNP attestation",
|
|
},
|
|
{
|
|
name: "vTPM with missing nonce",
|
|
args: []string{"vtpm"},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "vTPM nonce must be defined for vTPM attestation",
|
|
},
|
|
{
|
|
name: "Azure token with missing token nonce",
|
|
args: []string{"azure-token"},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "Token nonce must be defined for Azure attestation",
|
|
},
|
|
{
|
|
name: "TEE nonce too large",
|
|
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
|
|
},
|
|
{
|
|
name: "vTPM nonce too large",
|
|
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
|
},
|
|
{
|
|
name: "Token nonce too large",
|
|
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
|
},
|
|
{
|
|
name: "successful TDX attestation",
|
|
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
|
Return(nil).Run(func(args mock.Arguments) {
|
|
if _, err := args.Get(4).(*os.File).Write([]byte("mock tdx attestation")); err != nil {
|
|
t.Fatalf("Failed to write to attestation file: %v", err)
|
|
}
|
|
})
|
|
},
|
|
expectedOut: "Fetching TDX attestation report",
|
|
},
|
|
{
|
|
name: "file creation error",
|
|
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
|
setupMock: func(sdk *mocks.SDK) {
|
|
},
|
|
expectedErr: "Error creating attestation file",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
os.Remove(attestationFilePath)
|
|
os.Remove(azureAttestResultFilePath)
|
|
os.Remove(azureAttestTokenFilePath)
|
|
defer func() {
|
|
os.Remove(attestationFilePath)
|
|
os.Remove(azureAttestResultFilePath)
|
|
os.Remove(azureAttestTokenFilePath)
|
|
}()
|
|
|
|
mockSDK := new(mocks.SDK)
|
|
cli := &CLI{agentSDK: mockSDK}
|
|
tc.setupMock(mockSDK)
|
|
|
|
if tc.name == "file creation error" {
|
|
err := os.Mkdir(attestationFilePath, 0o755)
|
|
require.NoError(t, err)
|
|
defer os.Remove(attestationFilePath)
|
|
}
|
|
|
|
cmd := cli.NewGetAttestationCmd()
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
cmd.SetErr(&buf)
|
|
cmd.SetArgs(tc.args)
|
|
|
|
err := cmd.Execute()
|
|
output := buf.String()
|
|
|
|
if tc.expectedErr != "" {
|
|
assert.Contains(t, output, tc.expectedErr)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
if tc.expectedOut != "" {
|
|
assert.Contains(t, output, tc.expectedOut)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFileOperations(t *testing.T) {
|
|
defer setupTestEnvironment()()
|
|
|
|
t.Run("openInputFile", func(t *testing.T) {
|
|
attestationFile = ""
|
|
reader, err := openInputFile()
|
|
assert.Error(t, err)
|
|
assert.Equal(t, errEmptyFile, err)
|
|
assert.Nil(t, reader)
|
|
|
|
tempFile := createTempFile(t, []byte("test content"))
|
|
defer os.Remove(tempFile)
|
|
attestationFile = tempFile
|
|
reader, err = openInputFile()
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, reader)
|
|
if file, ok := reader.(*os.File); ok {
|
|
file.Close()
|
|
}
|
|
|
|
attestationFile = "non-existent-file.bin"
|
|
reader, err = openInputFile()
|
|
assert.Error(t, err)
|
|
assert.Nil(t, reader)
|
|
})
|
|
|
|
t.Run("createOutputFile", func(t *testing.T) {
|
|
output = ""
|
|
writer, err := createOutputFile()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, os.Stdout, writer)
|
|
|
|
tempDir := t.TempDir()
|
|
output = filepath.Join(tempDir, "test_output.txt")
|
|
writer, err = createOutputFile()
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, writer)
|
|
if file, ok := writer.(*os.File); ok {
|
|
file.Close()
|
|
}
|
|
|
|
output = "/invalid/path/file.txt"
|
|
writer, err = createOutputFile()
|
|
assert.Error(t, err)
|
|
assert.Nil(t, writer)
|
|
})
|
|
}
|
|
|
|
func TestValidationFunctions(t *testing.T) {
|
|
t.Run("validateFieldLength", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
fieldName string
|
|
field []byte
|
|
expectedLength int
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "nil field",
|
|
fieldName: "test",
|
|
field: nil,
|
|
expectedLength: 32,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "correct length",
|
|
fieldName: "test",
|
|
field: make([]byte, 32),
|
|
expectedLength: 32,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "incorrect length",
|
|
fieldName: "test",
|
|
field: make([]byte, 16),
|
|
expectedLength: 32,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateFieldLength(tt.fieldName, tt.field, tt.expectedLength)
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.fieldName)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDecodeJWTToJSONEdgeCases(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
expected string
|
|
hasError bool
|
|
}{
|
|
{
|
|
name: "empty input",
|
|
input: []byte(""),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "single part",
|
|
input: []byte("onlyonepart"),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "invalid base64 in header",
|
|
input: []byte("invalid@base64.validpart"),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "invalid base64 in payload",
|
|
input: []byte("eyJhbGciOiJIUzI1NiJ9.invalid@base64"),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "invalid JSON in header",
|
|
input: []byte("bm90anNvbg.eyJzdWIiOiJ0ZXN0In0"),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "invalid JSON in payload",
|
|
input: []byte("eyJhbGciOiJIUzI1NiJ9.bm90anNvbg"),
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "valid JWT with padding",
|
|
input: []byte("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"),
|
|
expected: `{
|
|
"header": {
|
|
"alg": "HS256"
|
|
},
|
|
"payload": {
|
|
"sub": "test"
|
|
}
|
|
}`,
|
|
hasError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := decodeJWTToJSON(tt.input)
|
|
if tt.hasError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
if tt.expected != "" {
|
|
assert.JSONEq(t, tt.expected, string(result))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMeasureCmdEdgeCases(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
args []string
|
|
mockSetup func(*mmocks.MeasurementProvider)
|
|
expectedError string
|
|
expectedOut string
|
|
}{
|
|
{
|
|
name: "no arguments",
|
|
args: []string{},
|
|
mockSetup: func(m *mmocks.MeasurementProvider) {
|
|
},
|
|
expectedError: "requires at least 1 arg(s), only received 0",
|
|
},
|
|
{
|
|
name: "single line output success",
|
|
args: []string{"test.igvm"},
|
|
mockSetup: func(m *mmocks.MeasurementProvider) {
|
|
m.On("Run", "test.igvm").Return([]byte("ABCDEF123456"), nil)
|
|
},
|
|
expectedOut: "",
|
|
},
|
|
{
|
|
name: "multi-line output error",
|
|
args: []string{"test.igvm"},
|
|
mockSetup: func(m *mmocks.MeasurementProvider) {
|
|
m.On("Run", "test.igvm").Return([]byte("line1\nline2\nERROR: something went wrong"), nil)
|
|
},
|
|
expectedError: "ERROR: something went wrong",
|
|
},
|
|
{
|
|
name: "measurement run error",
|
|
args: []string{"test.igvm"},
|
|
mockSetup: func(m *mmocks.MeasurementProvider) {
|
|
m.On("Run", "test.igvm").Return(nil, errors.New("measurement failed"))
|
|
},
|
|
expectedError: "measurement failed",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockMeasurement := new(mmocks.MeasurementProvider)
|
|
tt.mockSetup(mockMeasurement)
|
|
|
|
cli := &CLI{measurement: mockMeasurement}
|
|
cmd := cli.NewMeasureCmd("fake_binary_path")
|
|
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
cmd.SetErr(&buf)
|
|
cmd.SetArgs(tt.args)
|
|
|
|
err := cmd.Execute()
|
|
|
|
if tt.expectedError != "" {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.expectedError)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
if tt.expectedOut != "" {
|
|
assert.Contains(t, buf.String(), tt.expectedOut)
|
|
}
|
|
}
|
|
|
|
mockMeasurement.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateAttestationValidationCmdPreRunE(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
args []string
|
|
flags map[string]string
|
|
expectedErr string
|
|
}{
|
|
{
|
|
name: "no file path provided",
|
|
args: []string{},
|
|
flags: map[string]string{"mode": "snp"},
|
|
expectedErr: "please pass the attestation report file path",
|
|
},
|
|
{
|
|
name: "multiple file paths",
|
|
args: []string{"file1.bin", "file2.bin"},
|
|
flags: map[string]string{"mode": "snp"},
|
|
expectedErr: "please pass the attestation report file path",
|
|
},
|
|
{
|
|
name: "unknown mode",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "unknown"},
|
|
expectedErr: "unknown mode: unknown",
|
|
},
|
|
{
|
|
name: "SNP mode missing report_data",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "snp"},
|
|
expectedErr: "",
|
|
},
|
|
{
|
|
name: "SNP mode missing product",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "snp", "report_data": "123"},
|
|
expectedErr: "",
|
|
},
|
|
{
|
|
name: "vTPM mode missing nonce",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "vtpm"},
|
|
expectedErr: "",
|
|
},
|
|
{
|
|
name: "SNP-vTPM mode missing required flags",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "snp-vtpm"},
|
|
expectedErr: "",
|
|
},
|
|
{
|
|
name: "TDX mode missing report_data",
|
|
args: []string{"test.bin"},
|
|
flags: map[string]string{"mode": "tdx"},
|
|
expectedErr: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cli := &CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
for key, value := range tt.flags {
|
|
if err := cmd.Flags().Set(key, value); err != nil {
|
|
}
|
|
}
|
|
|
|
err := cmd.PreRunE(cmd, tt.args)
|
|
if tt.expectedErr != "" {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.expectedErr)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCloudProviderConfigurations(t *testing.T) {
|
|
defer setupTestEnvironment()()
|
|
|
|
tests := []struct {
|
|
name string
|
|
cloud string
|
|
expectedType string
|
|
}{
|
|
{
|
|
name: "none cloud provider",
|
|
cloud: CCNone,
|
|
expectedType: "vtpm",
|
|
},
|
|
{
|
|
name: "azure cloud provider",
|
|
cloud: CCAzure,
|
|
expectedType: "azure",
|
|
},
|
|
{
|
|
name: "gcp cloud provider",
|
|
cloud: CCGCP,
|
|
expectedType: "vtpm",
|
|
},
|
|
{
|
|
name: "default cloud provider",
|
|
cloud: "",
|
|
expectedType: "vtpm",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cli := &CLI{}
|
|
cmd := cli.NewValidateAttestationValidationCmd()
|
|
|
|
if err := cmd.Flags().Set("cloud", tt.cloud); err != nil {
|
|
t.Fatalf("Failed to set cloud flag: %v", err)
|
|
}
|
|
cloud, _ := cmd.Flags().GetString("cloud")
|
|
assert.Equal(t, tt.cloud, cloud)
|
|
|
|
assert.Contains(t, cmd.Short, tt.cloud)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFileOperationErrors(t *testing.T) {
|
|
defer setupTestEnvironment()()
|
|
|
|
t.Run("file close error handling", func(t *testing.T) {
|
|
tempFile := createTempFile(t, []byte("test content"))
|
|
defer os.Remove(tempFile)
|
|
|
|
assert.True(t, true)
|
|
})
|
|
|
|
t.Run("file write error handling", func(t *testing.T) {
|
|
tempFile := createTempFile(t, []byte("test content"))
|
|
defer os.Remove(tempFile)
|
|
|
|
err := os.Chmod(tempFile, 0o444)
|
|
require.NoError(t, err)
|
|
|
|
err = os.WriteFile(tempFile, []byte("new content"), 0o644)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("file read error handling", func(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
|
|
_, err := os.ReadFile(tempDir)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestContextCancellation(t *testing.T) {
|
|
defer setupTestEnvironment()()
|
|
|
|
mockSDK := new(mocks.SDK)
|
|
cli := &CLI{agentSDK: mockSDK}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
mockSDK.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
|
Return(context.Canceled)
|
|
|
|
cmd := cli.NewGetAttestationCmd()
|
|
cmd.SetContext(ctx)
|
|
|
|
var buf bytes.Buffer
|
|
cmd.SetOut(&buf)
|
|
cmd.SetErr(&buf)
|
|
|
|
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
|
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
|
|
|
|
err := cmd.Execute()
|
|
assert.NoError(t, err)
|
|
assert.Contains(t, buf.String(), "Failed to get attestation due to error")
|
|
}
|