mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
committed by
GitHub
parent
034547d667
commit
7ef25674c4
+10
-6
@@ -154,26 +154,30 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
Example: "get <report_data>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
log.Println("Getting attestation")
|
||||
cmd.Println("Getting attestation")
|
||||
|
||||
reportData, err := hex.DecodeString(args[0])
|
||||
if err != nil {
|
||||
log.Fatalf("attestation validation and verification failed with error: %s", err)
|
||||
cmd.Printf("attestation validation and verification failed with error: %s", err)
|
||||
return
|
||||
}
|
||||
if len(reportData) != agent.ReportDataSize {
|
||||
log.Fatalf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize)
|
||||
cmd.Printf("report data must be a hex encoded string of length %d bytes", agent.ReportDataSize)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData))
|
||||
if err != nil {
|
||||
log.Fatalf("Error retrieving attestation: %v", err)
|
||||
cmd.Printf("Error retrieving attestation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = os.WriteFile(attestationFilePath, result, 0o644); err != nil {
|
||||
log.Fatalf("Error saving attestation result: %v", err)
|
||||
cmd.Printf("Error saving attestation result: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Attestation result retrieved and saved successfully!")
|
||||
cmd.Println("Attestation result retrieved and saved successfully!")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
|
||||
func TestNewAttestationCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAttestationCmd()
|
||||
|
||||
assert.Equal(t, "attestation [command]", cmd.Use)
|
||||
assert.Equal(t, "Get and validate attestations", cmd.Short)
|
||||
}
|
||||
|
||||
func TestNewGetAttestationCmd(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
cmd := cli.NewGetAttestationCmd()
|
||||
var buf bytes.Buffer
|
||||
|
||||
cmd.SetOutput(&buf)
|
||||
|
||||
assert.Equal(t, "get", cmd.Use)
|
||||
assert.Equal(t, "Retrieve attestation information from agent. Report data expected in hex enoded string of length 64 bytes.", cmd.Short)
|
||||
|
||||
reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize)
|
||||
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData)).Return([]byte("mock attestation"), nil)
|
||||
|
||||
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Contains(t, buf.String(), "Attestation result retrieved and saved successfully!")
|
||||
|
||||
os.Remove(attestationFilePath)
|
||||
}
|
||||
|
||||
func TestNewValidateAttestationValidationCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
assert.Equal(t, "validate", cmd.Use)
|
||||
assert.Equal(t, "Validate and verify attestation information. The report is provided as a file path.", cmd.Short)
|
||||
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumTcb), cmd.Flag("minimum_tcb").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumLaunchTcb), cmd.Flag("minimum_lauch_tcb").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultGuestPolicy), cmd.Flag("guest_policy").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumGuestSvn), cmd.Flag("minimum_guest_svn").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumBuild), cmd.Flag("minimum_build").Value.String())
|
||||
assert.Equal(t, defaultCheckCrl, cmd.Flag("check_crl").Value.String() == "true")
|
||||
assert.Equal(t, fmt.Sprint(defaultTimeout), cmd.Flag("timeout").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMaxRetryDelay), cmd.Flag("max_retry_delay").Value.String())
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
cfgString = ""
|
||||
err := parseConfig()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg.RootOfTrust)
|
||||
assert.NotNil(t, cfg.Policy)
|
||||
|
||||
cfgString = `{"rootOfTrust":{"product":"test_product"},"policy":{"minimumGuestSvn":1}}`
|
||||
err = parseConfig()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_product", cfg.RootOfTrust.Product)
|
||||
assert.Equal(t, uint32(1), cfg.Policy.MinimumGuestSvn)
|
||||
|
||||
cfgString = `{"invalid_json"`
|
||||
err = parseConfig()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
|
||||
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
|
||||
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
}
|
||||
|
||||
err := parseHashes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, 1)
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, 1)
|
||||
|
||||
trustedAuthorHashes = []string{"invalid_hash"}
|
||||
err = parseHashes()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseFiles(t *testing.T) {
|
||||
attestationFile = "test_attestation.bin"
|
||||
authorKeyFile := "test_author_key.pem"
|
||||
idKeyFile := "test_id_key.pem"
|
||||
|
||||
err := os.WriteFile(attestationFile, []byte("test attestation"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(authorKeyFile, []byte("test author key"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(idKeyFile, []byte("test id key"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
trustedAuthorKeys = []string{authorKeyFile}
|
||||
trustedIdKeys = []string{idKeyFile}
|
||||
|
||||
err = parseFiles()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("test attestation"), attestation)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeys, 1)
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeys, 1)
|
||||
|
||||
os.Remove(attestationFile)
|
||||
os.Remove(authorKeyFile)
|
||||
os.Remove(idKeyFile)
|
||||
|
||||
attestationFile = "non_existent_file.bin"
|
||||
err = parseFiles()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
stepping = "10"
|
||||
platformInfo = "0xFF"
|
||||
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{
|
||||
Product: &sevsnp.SevProduct{},
|
||||
}
|
||||
}
|
||||
err := parseUints()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint32(10), cfg.Policy.Product.MachineStepping.Value)
|
||||
assert.Equal(t, uint64(255), cfg.Policy.PlatformInfo.Value)
|
||||
|
||||
stepping = "invalid"
|
||||
err = parseUints()
|
||||
assert.Error(t, err)
|
||||
|
||||
stepping = "10"
|
||||
platformInfo = "invalid"
|
||||
err = parseUints()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
}
|
||||
if cfg.RootOfTrust == nil {
|
||||
cfg.RootOfTrust = &check.RootOfTrust{}
|
||||
}
|
||||
cfg.Policy.ReportData = make([]byte, 64)
|
||||
cfg.Policy.HostData = make([]byte, 32)
|
||||
cfg.Policy.FamilyId = make([]byte, 16)
|
||||
cfg.Policy.ImageId = make([]byte, 16)
|
||||
cfg.Policy.ReportId = make([]byte, 32)
|
||||
cfg.Policy.ReportIdMa = make([]byte, 32)
|
||||
cfg.Policy.Measurement = make([]byte, 48)
|
||||
cfg.Policy.ChipId = make([]byte, 64)
|
||||
|
||||
err := validateInput()
|
||||
assert.NoError(t, err)
|
||||
|
||||
cfg.Policy.ReportData = make([]byte, 32)
|
||||
err = validateInput()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
assert.Equal(t, 16, getBase("0xFF"))
|
||||
assert.Equal(t, 8, getBase("0o77"))
|
||||
assert.Equal(t, 2, getBase("0b1010"))
|
||||
assert.Equal(t, 10, getBase("123"))
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChangeAttestationConfiguration(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "backend_info.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
initialConfig := AttestationConfiguration{
|
||||
SNPPolicy: &check.Policy{
|
||||
Measurement: make([]byte, measurementLength),
|
||||
HostData: make([]byte, hostDataLength),
|
||||
},
|
||||
}
|
||||
|
||||
initialJSON, err := json.Marshal(initialConfig)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base64Data string
|
||||
expectedLength int
|
||||
field fieldType
|
||||
expectError bool
|
||||
errorType error
|
||||
}{
|
||||
{
|
||||
name: "Valid Measurement",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Host Data",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, hostDataLength)),
|
||||
expectedLength: hostDataLength,
|
||||
field: hostDataField,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Base64",
|
||||
base64Data: "Invalid Base64",
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: true,
|
||||
errorType: errDecode,
|
||||
},
|
||||
{
|
||||
name: "Invalid Data Length",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength-1)),
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: true,
|
||||
errorType: errDataLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid Field Type",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
|
||||
expectedLength: measurementLength,
|
||||
field: fieldType(999),
|
||||
expectError: true,
|
||||
errorType: errBackendField,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := changeAttestationConfiguration(tmpfile.Name(), tt.base64Data, tt.expectedLength, tt.field)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.errorType)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tmpfile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
var config AttestationConfiguration
|
||||
err = json.Unmarshal(content, &config)
|
||||
require.NoError(t, err)
|
||||
|
||||
decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data)
|
||||
if tt.field == measurementField {
|
||||
assert.Equal(t, decodedData, config.SNPPolicy.Measurement)
|
||||
} else if tt.field == hostDataField {
|
||||
assert.Equal(t, decodedData, config.SNPPolicy.HostData)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBackendCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewBackendCmd()
|
||||
|
||||
assert.Equal(t, "backend [command]", cmd.Use)
|
||||
assert.Equal(t, "Change backend information", cmd.Short)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestNewAddMeasurementCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAddMeasurementCmd()
|
||||
|
||||
assert.Equal(t, "measurement", cmd.Use)
|
||||
assert.Equal(t, "Add measurement to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short)
|
||||
assert.Equal(t, "measurement <measurement> <backend_info.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestNewAddHostDataCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAddHostDataCmd()
|
||||
|
||||
assert.Equal(t, "hostdata", cmd.Use)
|
||||
assert.Equal(t, "Add host data to the backend info file. The value should be in base64. The second parameter is backend_info.json file", cmd.Short)
|
||||
assert.Equal(t, "hostdata <host-data> <backend_info.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewCABundleCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
tempDir, err := os.MkdirTemp("", "ca-bundle-test")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
manifestContent := []byte(`{"root_of_trust": {"product": "Milan"}}`)
|
||||
manifestPath := path.Join(tempDir, "manifest.json")
|
||||
err = os.WriteFile(manifestPath, manifestContent, 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cmd := cli.NewCABundleCmd(tempDir)
|
||||
cmd.SetArgs([]string{manifestPath})
|
||||
output := &bytes.Buffer{}
|
||||
cmd.SetOutput(output)
|
||||
err = cmd.Execute()
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedFilePath := path.Join(tempDir, "Milan", caBundleName)
|
||||
_, err = os.Stat(expectedFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(expectedFilePath)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, content)
|
||||
}
|
||||
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "save-to-file-test")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
filePath := path.Join(tempDir, "test-file.txt")
|
||||
content := []byte("test content")
|
||||
|
||||
err = saveToFile(filePath, content)
|
||||
assert.NoError(t, err)
|
||||
|
||||
savedContent, err := os.ReadFile(filePath)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, content, savedContent)
|
||||
|
||||
_, err = os.Stat(filePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
+3
-4
@@ -3,8 +3,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
)
|
||||
@@ -20,10 +18,11 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command {
|
||||
|
||||
hash, err := internal.ChecksumHex(path)
|
||||
if err != nil {
|
||||
log.Fatalf("Error computing hash: %v", err)
|
||||
cmd.Printf("Error computing hash: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Hash of file:", hash)
|
||||
cmd.Println("Hash of file:", hash)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
)
|
||||
|
||||
func TestNewFileHashCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewFileHashCmd()
|
||||
|
||||
if cmd.Use != "checksum" {
|
||||
t.Errorf("Expected Use to be 'checksum', got %s", cmd.Use)
|
||||
}
|
||||
|
||||
if cmd.Short != "Compute the sha3-256 hash of a file" {
|
||||
t.Errorf("Expected Short to be 'Compute the sha3-256 hash of a file', got %s", cmd.Short)
|
||||
}
|
||||
|
||||
if cmd.Example != "checksum <file>" {
|
||||
t.Errorf("Expected Example to be 'checksum <file>', got %s", cmd.Example)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileHashCmdRun(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewFileHashCmd()
|
||||
|
||||
content := []byte("test content")
|
||||
tmpfile, err := os.CreateTemp("", "example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(content); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd.SetOut(&output)
|
||||
cmd.SetErr(&output)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name()})
|
||||
err = cmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing command: %v", err)
|
||||
}
|
||||
|
||||
expectedHash, err := internal.ChecksumHex(tmpfile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Error computing expected hash: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output.String(), expectedHash) {
|
||||
t.Errorf("Expected output to contain hash %s, got %s", expectedHash, output.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileHashCmdInvalidArgs(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewFileHashCmd()
|
||||
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Error("Expected error when executing without arguments, got nil")
|
||||
}
|
||||
|
||||
cmd.SetArgs([]string{"file1", "file2"})
|
||||
err = cmd.Execute()
|
||||
if err == nil {
|
||||
t.Error("Expected error when executing with too many arguments, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileHashCmdNonExistentFile(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewFileHashCmd()
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd.SetOut(&output)
|
||||
cmd.SetErr(&output)
|
||||
|
||||
cmd.SetArgs([]string{"non_existent_file.txt"})
|
||||
err := cmd.Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Error executing command: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(output.String(), "Error computing hash") {
|
||||
t.Errorf("Expected output to contain 'Error computing hash', got %s", output.String())
|
||||
}
|
||||
}
|
||||
+4
-2
@@ -25,6 +25,8 @@ const (
|
||||
publicKeyType = "PUBLIC KEY"
|
||||
publicKeyFile = "public.pem"
|
||||
privateKeyFile = "private.pem"
|
||||
ECDSA = "ecdsa"
|
||||
ED25519 = "ed25519"
|
||||
)
|
||||
|
||||
var KeyType string
|
||||
@@ -39,7 +41,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
switch KeyType {
|
||||
case "ecdsa":
|
||||
case ECDSA:
|
||||
privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
log.Fatalf("Error generating keys: %v", err)
|
||||
@@ -52,7 +54,7 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
|
||||
generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType)
|
||||
|
||||
case "ed25519":
|
||||
case ED25519:
|
||||
pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
log.Fatalf("Error generating keys: %v", err)
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewKeysCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewKeysCmd()
|
||||
|
||||
if cmd.Use != "keys" {
|
||||
t.Errorf("Expected Use to be 'keys', got %s", cmd.Use)
|
||||
}
|
||||
|
||||
if cmd.Short != "Generate a new public/private key pair" {
|
||||
t.Errorf("Unexpected Short description: %s", cmd.Short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAndWriteKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType string
|
||||
}{
|
||||
{"RSA", "rsa"},
|
||||
{"ECDSA", "ecdsa"},
|
||||
{"ED25519", "ed25519"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
KeyType = tt.keyType
|
||||
cmd := (&CLI{}).NewKeysCmd()
|
||||
cmd.Run(cmd, []string{})
|
||||
|
||||
if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) {
|
||||
t.Errorf("Private key file was not created")
|
||||
}
|
||||
if _, err := os.Stat(publicKeyFile); os.IsNotExist(err) {
|
||||
t.Errorf("Public key file was not created")
|
||||
}
|
||||
|
||||
privKeyData, err := os.ReadFile(privateKeyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read private key file: %v", err)
|
||||
}
|
||||
privPem, _ := pem.Decode(privKeyData)
|
||||
if privPem == nil {
|
||||
t.Fatalf("Failed to decode private key PEM")
|
||||
}
|
||||
|
||||
var privKey interface{}
|
||||
switch tt.keyType {
|
||||
case "rsa":
|
||||
privKey, err = x509.ParsePKCS1PrivateKey(privPem.Bytes)
|
||||
case "ecdsa":
|
||||
privKey, err = x509.ParseECPrivateKey(privPem.Bytes)
|
||||
case "ed25519":
|
||||
privKey, err = x509.ParsePKCS8PrivateKey(privPem.Bytes)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
switch tt.keyType {
|
||||
case "rsa":
|
||||
if _, ok := privKey.(*rsa.PrivateKey); !ok {
|
||||
t.Errorf("Expected RSA private key, got %T", privKey)
|
||||
}
|
||||
case "ecdsa":
|
||||
if _, ok := privKey.(*ecdsa.PrivateKey); !ok {
|
||||
t.Errorf("Expected ECDSA private key, got %T", privKey)
|
||||
}
|
||||
case "ed25519":
|
||||
if _, ok := privKey.(ed25519.PrivateKey); !ok {
|
||||
t.Errorf("Expected ED25519 private key, got %T", privKey)
|
||||
}
|
||||
}
|
||||
|
||||
os.Remove(privateKeyFile)
|
||||
os.Remove(publicKeyFile)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user