mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
* Implement IMAMeasurements method in agentSDK and add corresponding unit tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for NewIMAMeasurements command in CLI Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add error assertion for command execution in NewIMAMeasurements test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix nil pointer dereference in Close method and update NewCreateVMCmd logic for manager client initialization Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation and improve cleanup handling in NewCreateVMCmd test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for state machine functionality Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementation for Algorithm interface and corresponding test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation in TestStopComputationIntegration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove redundant reset test cases from TestStateMachine_Reset Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix race condition in action call verification in TestStateMachine_HandleEvent Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance state machine with reset functionality and improve thread safety in event handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in state machine start function during tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove concurrent reset and send event test from state machine tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove error logging for Start function in transition tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementations for AgentService_IMAMeasurementsClient and Service Shutdown method; enhance progress tests for IMA measurements handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for FileStorage functionality including loading, saving, and concurrent access Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance tests by adding dataset and algorithm hashes in handleRunReqChunks; improve error handling in TestFileStorage_ErrorHandling cleanup Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestManagerClient_Process by adding new test cases for Agent state and Disconnect requests; update setupMocks to include grpcClient Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix graceful shutdown in gRPC server by adding nil checks for health and server instances Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestAttestation by adding mock expectations for VTpmAttestation and Attestation methods; update service call to include platform parameter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance gRPC Server by adding synchronization for start/stop methods; prevent multiple starts and ensure graceful shutdown Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for gRPC server methods including VM creation, removal, and info retrieval Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for SEVSNP and TDX host capabilities; remove unused vsock code Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add a newline for better readability in vm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add integration tests for gRPC client in cvm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove unused vsock dependencies and add comprehensive unit tests for GCP attestation functions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip GCP tests if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for error handling in attestation configuration and GCP commands Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in Azure VM test response writing Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip tests in GCP functions if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for Azure attestation provider and verifier Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for TPM functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for attestation functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add validation for teeNonce in TeeAttestation and implement comprehensive tests for provider methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor error messages in TDX attestation tests for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix error message in TeeAttestation test for valid nonce case Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MeasurementProvider mock and update mockery configuration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add logging for product in parseUints and rename test functions for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor TestSevsnpverify to reset configuration and improve error logging Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
85a2b7a6c8
commit
4e8057f481
@@ -3,19 +3,33 @@
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
@@ -217,6 +231,346 @@ func TestGetPlatformTypeFromOID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCertificateExtension(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
attestationPB := prepVerifyAttReport(t)
|
||||
err = setAttestationPolicy(attestationPB, tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err = rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := append(pubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
extension []byte
|
||||
pubKey []byte
|
||||
nonce []byte
|
||||
platformType attestation.PlatformType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid extension with SNPvTPM",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid platform type",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: 999,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty extension",
|
||||
extension: []byte{},
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty public key",
|
||||
extension: hashNonce[:],
|
||||
pubKey: []byte{},
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty nonce",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: []byte{},
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
err := VerifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateExtension(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation-data"), nil)
|
||||
|
||||
pubKey := []byte("test-public-key")
|
||||
nonce := make([]byte, 32)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
|
||||
|
||||
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testOID, extension.Id)
|
||||
assert.Equal(t, []byte("mock-attestation-data"), extension.Value)
|
||||
}
|
||||
|
||||
func TestGetCertificateWithSelfSigned(t *testing.T) {
|
||||
getCertFunc := GetCertificate("", "")
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverName := hex.EncodeToString(nonce) + ".nonce"
|
||||
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
cert, err := getCertFunc(clientHello)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Expected error due to missing attestation setup: %v", err)
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NotNil(t, cert)
|
||||
assert.NotEmpty(t, cert.Certificate)
|
||||
assert.NotNil(t, cert.PrivateKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateWithCA(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
mockCert := certssdk.Certificate{
|
||||
Certificate: "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIBATANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1UZXN0IENBIFJvb3QwHhcNMjMwMzMxMDAwMDAwWhcNMjQwMzMxMDAwMDAwWjAYMRYwFAYDVQQDDA1UZXN0IENlcnRpZmljYXRlMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEtest-key-data-here\n-----END CERTIFICATE-----",
|
||||
}
|
||||
|
||||
response, _ := json.Marshal(mockCert)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(response); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
getCertFunc := GetCertificate(mockServer.URL, "test-cvm-id")
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverName := hex.EncodeToString(nonce) + ".nonce"
|
||||
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
_, err = getCertFunc(clientHello)
|
||||
if err != nil {
|
||||
t.Logf("Expected error due to missing attestation setup: %v", err)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateInvalidServerName(t *testing.T) {
|
||||
getCertFunc := GetCertificate("", "")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
serverName string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "Missing .nonce suffix",
|
||||
serverName: "invalidname",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Too short server name",
|
||||
serverName: "short",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Invalid nonce encoding",
|
||||
serverName: "invalidhex.nonce",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Wrong nonce length",
|
||||
serverName: "deadbeef.nonce",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: c.serverName,
|
||||
}
|
||||
|
||||
cert, err := getCertFunc(clientHello)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), c.expectErr)
|
||||
assert.Nil(t, cert)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessRequest(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/success":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(`{"message": "success"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case "/notfound":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
if _, err := w.Write([]byte(`{"error": "not found"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case "/headers":
|
||||
if r.Header.Get("X-Custom-Header") == "test-value" {
|
||||
w.Header().Set("X-Response-Header", "received")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(`{"headers": "ok"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
url string
|
||||
data []byte
|
||||
headers map[string]string
|
||||
expectedRespCodes []int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Successful GET request",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/success",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Successful POST request with data",
|
||||
method: http.MethodPost,
|
||||
url: testServer.URL + "/success",
|
||||
data: []byte(`{"test": "data"}`),
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request with custom headers",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/headers",
|
||||
data: nil,
|
||||
headers: map[string]string{"X-Custom-Header": "test-value"},
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request with unexpected status code",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/notfound",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Request with multiple expected status codes",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/notfound",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK, http.StatusNotFound},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request to invalid URL",
|
||||
method: http.MethodGet,
|
||||
url: "invalid-url",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
headers, body, err := processRequest(c.method, c.url, c.data, c.headers, c.expectedRespCodes...)
|
||||
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, headers)
|
||||
assert.NotNil(t, body)
|
||||
|
||||
if c.name == "Request with custom headers" {
|
||||
assert.Equal(t, "received", headers.Get("X-Response-Header"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateExtensionError(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("failed to get attestation"))
|
||||
|
||||
pubKey := []byte("test-public-key")
|
||||
nonce := make([]byte, 32)
|
||||
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
|
||||
|
||||
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get attestation")
|
||||
assert.Equal(t, pkix.Extension{}, extension)
|
||||
}
|
||||
|
||||
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
|
||||
file, err := os.ReadFile("../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCCPlatform(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sevSnpGuestExists bool
|
||||
sevSnpGuestvTPMExists bool
|
||||
tdxGuestExists bool
|
||||
isAzure bool
|
||||
expected PlatformType
|
||||
}{
|
||||
{
|
||||
name: "No CC platform detected",
|
||||
sevSnpGuestExists: false,
|
||||
sevSnpGuestvTPMExists: false,
|
||||
tdxGuestExists: false,
|
||||
isAzure: false,
|
||||
expected: NoCC,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CCPlatform()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevSnpGuestDeviceExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openDeviceErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "device does not exist or fails to open",
|
||||
openDeviceErr: fmt.Errorf("device not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SevSnpGuestDeviceExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevSnpGuestvTPMExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTPMExists bool
|
||||
sevSnpExists bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "vTPM exists but SEV-SNP does not",
|
||||
vTPMExists: true,
|
||||
sevSnpExists: false,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "SEV-SNP exists but vTPM does not",
|
||||
vTPMExists: false,
|
||||
sevSnpExists: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "neither exists",
|
||||
vTPMExists: false,
|
||||
sevSnpExists: false,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SevSnpGuestvTPMExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVTPMExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openTPMErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "TPM fails to open",
|
||||
openTPMErr: fmt.Errorf("TPM not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := vTPMExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAzureVM(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTPMExists bool
|
||||
statusCode int
|
||||
responseBody string
|
||||
httpError error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Azure VM with empty response body",
|
||||
vTPMExists: true,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: "",
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Azure VM with non-200 status code",
|
||||
vTPMExists: true,
|
||||
statusCode: http.StatusNotFound,
|
||||
responseBody: "",
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP request error",
|
||||
vTPMExists: true,
|
||||
statusCode: 0,
|
||||
responseBody: "",
|
||||
httpError: fmt.Errorf("connection failed"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "vTPM does not exist",
|
||||
vTPMExists: false,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"compute":{"name":"test-vm"}}`,
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
assert.Equal(t, "true", r.Header.Get("Metadata"))
|
||||
expectedURL := fmt.Sprintf("/?api-version=%s", azureApiVersion)
|
||||
assert.Equal(t, expectedURL, r.URL.String())
|
||||
|
||||
if tt.httpError != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(tt.statusCode)
|
||||
if tt.responseBody != "" {
|
||||
if _, err := w.Write([]byte(tt.responseBody)); err != nil {
|
||||
t.Fatalf("Failed to write response body: %v", err)
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
if tt.httpError != nil {
|
||||
server.Close()
|
||||
}
|
||||
|
||||
result := isAzureVM()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDXGuestDeviceExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openDeviceErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "TDX device does not exist or fails to open",
|
||||
openDeviceErr: fmt.Errorf("device not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TDXGuestDeviceExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,578 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
testNonce = []byte("test-nonce-12345678901234567890123456789012")
|
||||
testReport = []byte("test-report-data")
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want attestation.Provider
|
||||
}{
|
||||
{
|
||||
name: "creates new provider successfully",
|
||||
want: provider{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewProvider()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Attestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "maa parameters error",
|
||||
teeNonce: testNonce,
|
||||
vTpmNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to get report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "maa parameters error",
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to get report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.TeeAttestation(tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "failed to fetch Azure token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = server.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.AzureAttestationToken(tt.tokenNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writer io.Writer
|
||||
}{
|
||||
{
|
||||
name: "creates verifier with buffer writer",
|
||||
writer: &bytes.Buffer{},
|
||||
},
|
||||
{
|
||||
name: "creates verifier with nil writer",
|
||||
writer: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(tt.writer)
|
||||
|
||||
verifier, ok := v.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.writer, verifier.writer)
|
||||
assert.NotNil(t, verifier.Policy)
|
||||
assert.NotNil(t, verifier.Policy.Config)
|
||||
assert.NotNil(t, verifier.Policy.PcrConfig)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writer io.Writer
|
||||
policy *attestation.Config
|
||||
}{
|
||||
{
|
||||
name: "creates verifier with custom policy",
|
||||
writer: &bytes.Buffer{},
|
||||
policy: &attestation.Config{
|
||||
Config: &check.Config{
|
||||
Policy: &check.Policy{},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "creates verifier with nil policy",
|
||||
writer: &bytes.Buffer{},
|
||||
policy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifierWithPolicy(tt.writer, tt.policy)
|
||||
|
||||
verifier, ok := v.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.writer, verifier.writer)
|
||||
assert.NotNil(t, verifier.Policy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifTeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "empty report",
|
||||
report: []byte{},
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid report format",
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil nonce",
|
||||
report: testReport,
|
||||
teeNonce: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
err := v.VerifTeeAttestation(tt.report, tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyAttestation(t *testing.T) {
|
||||
validQuote := &attest.Attestation{
|
||||
TeeAttestation: &attest.Attestation_SevSnpAttestation{
|
||||
SevSnpAttestation: &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
HostData: []byte("test-data"),
|
||||
},
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
Extras: make(map[string][]byte),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
validReport, _ := proto.Marshal(validQuote)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
report: validReport,
|
||||
teeNonce: testNonce,
|
||||
vTpmNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to verify vTPM attestation report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
err := v.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
maaURL string
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "error fetching azure token",
|
||||
},
|
||||
{
|
||||
name: "invalid url",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "error fetching azure token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var url string
|
||||
if tt.setupServer != nil {
|
||||
server := tt.setupServer()
|
||||
if server != nil {
|
||||
defer server.Close()
|
||||
url = server.URL
|
||||
}
|
||||
}
|
||||
|
||||
if tt.name == "invalid url" {
|
||||
url = "invalid-url"
|
||||
}
|
||||
|
||||
result, err := FetchAzureAttestationToken(tt.tokenNonce, url)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid-token",
|
||||
setupServer: nil,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to parse token",
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
setupServer: nil,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to parse token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupServer != nil {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = server.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
}
|
||||
|
||||
result, err := validateToken(tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_FullAttestationFlow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("full attestation flow with mock server", func(t *testing.T) {
|
||||
maaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/attest":
|
||||
response := map[string]interface{}{
|
||||
"token": createMockJWT(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
t.Fatalf("Failed to encode response: %v", err)
|
||||
}
|
||||
case "/.well-known/openid_configuration":
|
||||
config := map[string]interface{}{
|
||||
"jwks_uri": "maaServer.URL" + "/certs",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(config); err != nil {
|
||||
t.Fatalf("Failed to encode OpenID configuration: %v", err)
|
||||
}
|
||||
case "/certs":
|
||||
jwks := map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kid": "test-kid",
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"n": "test-n-value",
|
||||
"e": "AQAB",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(jwks); err != nil {
|
||||
t.Fatalf("Failed to encode JWKS: %v", err)
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer maaServer.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = maaServer.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
provider := NewProvider()
|
||||
verifier := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
teeNonce := []byte("test-tee-nonce-1234567890123456789012")
|
||||
vtpmNonce := []byte("test-vtpm-nonce-123456789012345678901")
|
||||
|
||||
teeReport, err := provider.TeeAttestation(teeNonce)
|
||||
if err != nil {
|
||||
t.Logf("TEE attestation failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
vtpmReport, err := provider.VTpmAttestation(vtpmNonce)
|
||||
if err != nil {
|
||||
t.Logf("vTPM attestation failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
token, err := provider.AzureAttestationToken(teeNonce)
|
||||
if err != nil {
|
||||
t.Logf("Azure attestation token failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
assert.NotNil(t, provider)
|
||||
assert.NotNil(t, verifier)
|
||||
|
||||
t.Logf("TEE report length: %d", len(teeReport))
|
||||
t.Logf("vTPM report length: %d", len(vtpmReport))
|
||||
t.Logf("Token length: %d", len(token))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegration_ErrorPropagation(t *testing.T) {
|
||||
t.Run("error propagation through full stack", func(t *testing.T) {
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
if _, err := w.Write([]byte("Internal Server Error")); err != nil {
|
||||
t.Fatalf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = failingServer.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
provider := NewProvider()
|
||||
|
||||
_, err := provider.AzureAttestationToken([]byte("test-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to fetch Azure token")
|
||||
|
||||
_, err = GenerateAttestationPolicy("invalid-token", "test-product", 1)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to validate token")
|
||||
})
|
||||
}
|
||||
|
||||
func createMockJWT() string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"x-ms-isolation-tee": map[string]interface{}{
|
||||
"x-ms-sevsnpvm-familyId": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-imageId": "fedcba0987654321",
|
||||
"x-ms-sevsnpvm-launchmeasurement": "abcdef1234567890",
|
||||
"x-ms-sevsnpvm-bootloader-svn": float64(1),
|
||||
"x-ms-sevsnpvm-tee-svn": float64(2),
|
||||
"x-ms-sevsnpvm-snpfw-svn": float64(3),
|
||||
"x-ms-sevsnpvm-microcode-svn": float64(4),
|
||||
"x-ms-sevsnpvm-guestsvn": float64(5),
|
||||
"x-ms-sevsnpvm-idkeydigest": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-reportid": "fedcba0987654321",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["jku"] = "https://test-url.com"
|
||||
token.Header["kid"] = "test-kid"
|
||||
|
||||
// Return unsigned token for testing
|
||||
return token.Raw
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MeasurementProvider is an autogenerated mock type for the MeasurementProvider type
|
||||
type MeasurementProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MeasurementProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MeasurementProvider) EXPECT() *MeasurementProvider_Expecter {
|
||||
return &MeasurementProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: binaryPath
|
||||
func (_m *MeasurementProvider) Run(binaryPath string) ([]byte, error) {
|
||||
ret := _m.Called(binaryPath)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) ([]byte, error)); ok {
|
||||
return rf(binaryPath)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string) []byte); ok {
|
||||
r0 = rf(binaryPath)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(binaryPath)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MeasurementProvider_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type MeasurementProvider_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
// - binaryPath string
|
||||
func (_e *MeasurementProvider_Expecter) Run(binaryPath interface{}) *MeasurementProvider_Run_Call {
|
||||
return &MeasurementProvider_Run_Call{Call: _e.mock.On("Run", binaryPath)}
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) Run(run func(binaryPath string)) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) Return(_a0 []byte, _a1 error) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) RunAndReturn(run func(string) ([]byte, error)) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with no fields
|
||||
func (_m *MeasurementProvider) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MeasurementProvider_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type MeasurementProvider_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *MeasurementProvider_Expecter) Stop() *MeasurementProvider_Stop_Call {
|
||||
return &MeasurementProvider_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) Run(run func()) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) Return(_a0 error) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) RunAndReturn(run func() error) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMeasurementProvider creates a new instance of MeasurementProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMeasurementProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MeasurementProvider {
|
||||
mock := &MeasurementProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package gcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"cloud.google.com/go/storage"
|
||||
"github.com/google/gce-tcb-verifier/proto/endorsement"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestExtract384BitMeasurement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupMock func()
|
||||
expected string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil attestation",
|
||||
attestation: nil,
|
||||
expectError: true,
|
||||
errorMsg: "report is nil",
|
||||
},
|
||||
{
|
||||
name: "short report",
|
||||
attestation: &sevsnp.Attestation{Report: &sevsnp.Report{}},
|
||||
expectError: true,
|
||||
errorMsg: "failed to transform report to binary",
|
||||
},
|
||||
{
|
||||
name: "empty report",
|
||||
attestation: &sevsnp.Attestation{},
|
||||
expectError: true,
|
||||
errorMsg: "failed to transform report to binary",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := Extract384BitMeasurement(tt.attestation)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Empty(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLaunchEndorsement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
measurement384 string
|
||||
setupMock func() ([]byte, error)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
return proto.Marshal(launchEndorsement)
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "storage client error",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return nil, errors.New("storage client error")
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "object not found",
|
||||
measurement384: "non-existent-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return nil, storage.ErrObjectNotExist
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "invalid protobuf data",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return []byte("invalid protobuf data"), nil
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// skip if credentials are not set
|
||||
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
|
||||
t.Skip("Skipping test due to missing GCP credentials")
|
||||
}
|
||||
|
||||
_, err := GetLaunchEndorsement(ctx, tt.measurement384)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAttestationPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endorsement *endorsement.VMGoldenMeasurement
|
||||
vcpuNum uint32
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid endorsement",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing measurement for vcpu",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{2: []byte("test-measurement")},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty measurements map",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := GenerateAttestationPolicy(tt.endorsement, tt.vcpuNum)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotNil(t, result.Config)
|
||||
assert.NotNil(t, result.Config.Policy)
|
||||
assert.NotNil(t, result.Config.RootOfTrust)
|
||||
assert.NotNil(t, result.PcrConfig)
|
||||
|
||||
assert.Equal(t, tt.endorsement.SevSnp.Policy, result.Config.Policy.Policy)
|
||||
assert.Equal(t, tt.endorsement.SevSnp.Measurements[tt.vcpuNum], result.Config.Policy.Measurement)
|
||||
assert.False(t, result.Config.RootOfTrust.DisallowNetwork)
|
||||
assert.True(t, result.Config.RootOfTrust.CheckCrl)
|
||||
assert.Equal(t, "Milan", result.Config.RootOfTrust.Product)
|
||||
assert.Equal(t, "Milan", result.Config.RootOfTrust.ProductLine)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadOvmfFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful download",
|
||||
digest: "test-digest",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "storage client error",
|
||||
digest: "test-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "object not found",
|
||||
digest: "non-existent-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "read error",
|
||||
digest: "test-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "empty digest",
|
||||
digest: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// skip if credentials are not set
|
||||
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
|
||||
t.Skip("Skipping test due to missing GCP credentials")
|
||||
}
|
||||
|
||||
_, err := DownloadOvmfFile(ctx, tt.digest)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -253,148 +253,6 @@ func (_c *Provider_VTpmAttestation_Call) RunAndReturn(run func([]byte) ([]byte,
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
|
||||
func (_m *Provider) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifTeeAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
|
||||
type Provider_VerifTeeAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifTeeAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
func (_e *Provider_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Provider_VerifTeeAttestation_Call {
|
||||
return &Provider_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) Return(_a0 error) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
|
||||
func (_m *Provider) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifVTpmAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
|
||||
type Provider_VerifVTpmAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Provider_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Provider_VerifVTpmAttestation_Call {
|
||||
return &Provider_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) Return(_a0 error) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
|
||||
func (_m *Provider) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
|
||||
type Provider_VerifyAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Provider_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Provider_VerifyAttestation_Call {
|
||||
return &Provider_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) Return(_a0 error) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewProvider creates a new instance of Provider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewProvider(t interface {
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// Verifier is an autogenerated mock type for the Verifier type
|
||||
type Verifier struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Verifier_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Verifier) EXPECT() *Verifier_Expecter {
|
||||
return &Verifier_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// JSONToPolicy provides a mock function with given fields: path
|
||||
func (_m *Verifier) JSONToPolicy(path string) error {
|
||||
ret := _m.Called(path)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for JSONToPolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string) error); ok {
|
||||
r0 = rf(path)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_JSONToPolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JSONToPolicy'
|
||||
type Verifier_JSONToPolicy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// JSONToPolicy is a helper method to define mock.On call
|
||||
// - path string
|
||||
func (_e *Verifier_Expecter) JSONToPolicy(path interface{}) *Verifier_JSONToPolicy_Call {
|
||||
return &Verifier_JSONToPolicy_Call{Call: _e.mock.On("JSONToPolicy", path)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) Run(run func(path string)) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) Return(_a0 error) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) RunAndReturn(run func(string) error) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
|
||||
func (_m *Verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifTeeAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
|
||||
type Verifier_VerifTeeAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifTeeAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Verifier_VerifTeeAttestation_Call {
|
||||
return &Verifier_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) Return(_a0 error) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
|
||||
func (_m *Verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifVTpmAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
|
||||
type Verifier_VerifVTpmAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Verifier_VerifVTpmAttestation_Call {
|
||||
return &Verifier_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) Return(_a0 error) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
|
||||
func (_m *Verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
|
||||
type Verifier_VerifyAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Verifier_VerifyAttestation_Call {
|
||||
return &Verifier_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) Return(_a0 error) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewVerifier creates a new instance of Verifier. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewVerifier(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Verifier {
|
||||
mock := &Verifier{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -7,11 +7,10 @@
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -19,51 +18,361 @@ import (
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer func() {
|
||||
os.Setenv("HOME", originalHome)
|
||||
}()
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
cocosDir := tempDir + "/.cocos/Milan"
|
||||
os.Setenv("HOME", tempDir)
|
||||
|
||||
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
|
||||
bundlePath := path.Join(cocosDir, caBundleName)
|
||||
err = os.WriteFile(bundlePath, bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
Policy: &check.Policy{},
|
||||
config := &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
err error
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupFunc func()
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation",
|
||||
name: "Empty attestation - creates new chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
CertificateChain: nil,
|
||||
},
|
||||
err: nil,
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain",
|
||||
name: "Attestation with existing chain - no changes needed",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
setupFunc: func() {},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Attestation with empty chain - tries to load from file",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "No bundle file exists - no error",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {
|
||||
os.Remove(bundlePath)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := fillInAttestationLocal(tt.attestation, &config)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
os.Setenv("HOME", tempDir)
|
||||
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write bundle file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tt.setupFunc()
|
||||
|
||||
err := fillInAttestationLocal(tt.attestation, config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProductName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
product string
|
||||
expected sevsnp.SevProduct_SevProductName
|
||||
}{
|
||||
{
|
||||
name: "Milan product",
|
||||
product: sevSnpProductMilan,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
{
|
||||
name: "Genoa product",
|
||||
product: sevSnpProductGenoa,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
{
|
||||
name: "Unknown product",
|
||||
product: "UnknownProduct",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Empty product",
|
||||
product: "",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive - milan lowercase",
|
||||
product: "milan",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetProductName(tt.product)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Invalid product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "InvalidProduct",
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "product name must be",
|
||||
},
|
||||
{
|
||||
name: "Valid Milan product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Valid Genoa product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductGenoa,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Config with existing product policy",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := verifyReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Basic validation test",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
{
|
||||
name: "Validation with report data",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
ReportData: []byte("test report datatest report datatest report datatest report data"),
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reportData []byte
|
||||
vmpl uint
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Report data too large",
|
||||
reportData: make([]byte, Nonce+1),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Valid report data size",
|
||||
reportData: make([]byte, 32),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Maximum valid report data size",
|
||||
reportData: make([]byte, Nonce),
|
||||
vmpl: 1,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Empty report data",
|
||||
reportData: []byte{},
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FetchAttestation(tt.reportData, tt.vmpl)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, result)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLeveledQuoteProvider(t *testing.T) {
|
||||
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
|
||||
provider, err := GetLeveledQuoteProvider()
|
||||
|
||||
if err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,6 +45,14 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
if teeNonce == nil {
|
||||
return nil, errors.New("tee nonce is required for TDX attestation")
|
||||
}
|
||||
|
||||
if len(teeNonce) != 64 {
|
||||
return nil, fmt.Errorf("invalid tee nonce length: expected 64 bytes, got %d bytes", len(teeNonce))
|
||||
}
|
||||
|
||||
quoteprovider, err := client.GetQuoteProvider()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errOpenTDXDevice)
|
||||
|
||||
@@ -0,0 +1,622 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package tdx
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-tdx-guest/proto/checkconfig"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want attestation.Provider
|
||||
}{
|
||||
{
|
||||
name: "should create new provider successfully",
|
||||
want: provider{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewProvider()
|
||||
assert.IsType(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Attestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should handle empty nonces",
|
||||
teeNonce: []byte{},
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
|
||||
},
|
||||
{
|
||||
name: "should handle valid nonces",
|
||||
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
{
|
||||
name: "should handle nil nonces",
|
||||
teeNonce: nil,
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "tee nonce is required for TDX attestation",
|
||||
},
|
||||
{
|
||||
name: "should handle large nonce",
|
||||
teeNonce: make([]byte, 64),
|
||||
vTpmNonce: make([]byte, 32),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should handle empty nonce",
|
||||
teeNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
|
||||
},
|
||||
{
|
||||
name: "should handle valid nonce",
|
||||
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
{
|
||||
name: "should handle nil nonce",
|
||||
teeNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "tee nonce is required for TDX attestation",
|
||||
},
|
||||
{
|
||||
name: "should handle 64-byte nonce",
|
||||
teeNonce: make([]byte, 64),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.TeeAttestation(tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_VTpmAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for empty nonce",
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for valid nonce",
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil nonce",
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.VTpmAttestation(tt.vTpmNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for empty nonce",
|
||||
tokenNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for valid nonce",
|
||||
tokenNonce: []byte("token-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil nonce",
|
||||
tokenNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.AzureAttestationToken(tt.tokenNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "should create new verifier successfully",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewVerifier()
|
||||
v, ok := got.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, v.Policy)
|
||||
assert.NotNil(t, v.Policy.RootOfTrust)
|
||||
assert.NotNil(t, v.Policy.Policy)
|
||||
assert.NotNil(t, v.Policy.Policy.HeaderPolicy)
|
||||
assert.NotNil(t, v.Policy.Policy.TdQuoteBodyPolicy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *checkconfig.Config
|
||||
}{
|
||||
{
|
||||
name: "should create verifier with nil policy",
|
||||
policy: nil,
|
||||
},
|
||||
{
|
||||
name: "should create verifier with valid policy",
|
||||
policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should create verifier with empty policy",
|
||||
policy: &checkconfig.Config{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewVerifierWithPolicy(tt.policy)
|
||||
v, ok := got.(verifier)
|
||||
assert.True(t, ok)
|
||||
|
||||
if tt.policy == nil {
|
||||
assert.NotNil(t, v.Policy)
|
||||
assert.NotNil(t, v.Policy.RootOfTrust)
|
||||
assert.NotNil(t, v.Policy.Policy)
|
||||
} else {
|
||||
assert.Equal(t, tt.policy, v.Policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifTeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error when policy is nil",
|
||||
verifier: verifier{
|
||||
Policy: nil,
|
||||
},
|
||||
report: []byte("test-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "tdx policy is not provided",
|
||||
},
|
||||
{
|
||||
name: "should handle invalid report format",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should handle empty report",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte{},
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should handle nil report",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: nil,
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifTeeAttestation(tt.report, tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifVTpmAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for any input",
|
||||
verifier: verifier{},
|
||||
report: []byte("test-report"),
|
||||
vTpmNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty inputs",
|
||||
verifier: verifier{},
|
||||
report: []byte{},
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil inputs",
|
||||
verifier: verifier{},
|
||||
report: nil,
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifVTpmAttestation(tt.report, tt.vTpmNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should delegate to VerifTeeAttestation with nil policy",
|
||||
verifier: verifier{
|
||||
Policy: nil,
|
||||
},
|
||||
report: []byte("test-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "tdx policy is not provided",
|
||||
},
|
||||
{
|
||||
name: "should delegate to VerifTeeAttestation with valid policy",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_JSONToPolicy(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testPolicy := &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{
|
||||
HeaderPolicy: &checkconfig.HeaderPolicy{},
|
||||
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyJSON, err := protojson.Marshal(testPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
|
||||
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
path string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should load valid policy file",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: validPolicyFile,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error for non-existent file",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: filepath.Join(tempDir, "non_existent.json"),
|
||||
wantErr: true,
|
||||
errContains: "no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "should return error for invalid JSON",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: invalidPolicyFile,
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty path",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: "",
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.JSONToPolicy(tt.path)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTDXAttestationPolicy(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testPolicy := &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{
|
||||
HeaderPolicy: &checkconfig.HeaderPolicy{},
|
||||
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyJSON, err := protojson.Marshal(testPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
|
||||
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
emptyFile := filepath.Join(tempDir, "empty.json")
|
||||
err = os.WriteFile(emptyFile, []byte{}, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policyPath string
|
||||
policy *checkconfig.Config
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should read valid policy file",
|
||||
policyPath: validPolicyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error for non-existent file",
|
||||
policyPath: filepath.Join(tempDir, "non_existent.json"),
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "should return error for invalid JSON",
|
||||
policyPath: invalidPolicyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty file",
|
||||
policyPath: emptyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty path",
|
||||
policyPath: "",
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ReadTDXAttestationPolicy(tt.policyPath, tt.policy)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tt.policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,11 @@
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -13,6 +17,8 @@ import (
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
ptpm "github.com/google/go-tpm-tools/proto/tpm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
@@ -24,6 +30,633 @@ const sevSnpProductMilan = "Milan"
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
type mockTPM struct {
|
||||
*bytes.Buffer
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (m *mockTPM) Close() error {
|
||||
return m.closeErr
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockWriter) Write(p []byte) (n int, err error) {
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
m.data = append(m.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestOpenTpm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
externalTPM io.ReadWriteCloser
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "External TPM available",
|
||||
externalTPM: &mockTPM{Buffer: &bytes.Buffer{}},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "No external TPM",
|
||||
externalTPM: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalExternalTPM := ExternalTPM
|
||||
defer func() { ExternalTPM = originalExternalTPM }()
|
||||
|
||||
ExternalTPM = tt.externalTPM
|
||||
|
||||
tpm, err := OpenTpm()
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
if tt.externalTPM != nil {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tpm)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTpmEventLog(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "event_log")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
testData := []byte("test event log data")
|
||||
_, err = tempFile.Write(testData)
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
tpm := &tpm{ReadWriteCloser: &mockTPM{Buffer: &bytes.Buffer{}}}
|
||||
|
||||
_, err = tpm.EventLog()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeAttestation bool
|
||||
vmpl uint
|
||||
}{
|
||||
{
|
||||
name: "TEE attestation enabled",
|
||||
teeAttestation: true,
|
||||
vmpl: 1,
|
||||
},
|
||||
{
|
||||
name: "TEE attestation disabled",
|
||||
teeAttestation: false,
|
||||
vmpl: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := NewProvider(tt.teeAttestation, tt.vmpl)
|
||||
assert.NotNil(t, provider)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderAzureAttestationToken(t *testing.T) {
|
||||
provider := NewProvider(false, 0)
|
||||
|
||||
token, err := provider.AzureAttestationToken([]byte("test-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, token)
|
||||
assert.Contains(t, err.Error(), "Azure attestation token is not supported")
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
verifier := NewVerifier(writer)
|
||||
|
||||
assert.NotNil(t, verifier)
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *attestation.Config
|
||||
}{
|
||||
{
|
||||
name: "With policy",
|
||||
policy: policy,
|
||||
},
|
||||
{
|
||||
name: "Without policy (nil)",
|
||||
policy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier := NewVerifierWithPolicy([]byte("test-key"), writer, tt.policy)
|
||||
assert.NotNil(t, verifier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *attest.Attestation
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation",
|
||||
attestation: &attest.Attestation{
|
||||
AkPub: []byte("test-key"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nil attestation",
|
||||
attestation: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := marshalQuote(tt.attestation)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, data)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.attestation != nil {
|
||||
assert.NotEmpty(t, data)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpectedPCRValues(t *testing.T) {
|
||||
testPCRValue := make([]byte, 32)
|
||||
for i := range testPCRValue {
|
||||
testPCRValue[i] = byte(i)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *attest.Attestation
|
||||
policy *attestation.Config
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Matching PCR values SHA256",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": hex.EncodeToString(testPCRValue),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Mismatched PCR values",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": hex.EncodeToString(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "expected",
|
||||
},
|
||||
{
|
||||
name: "Unsupported hash algorithm",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_HASH_INVALID,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "hash algo is not supported",
|
||||
},
|
||||
{
|
||||
name: "Invalid PCR index",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"invalid": hex.EncodeToString(testPCRValue),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "error converting PCR index to int32",
|
||||
},
|
||||
{
|
||||
name: "Invalid PCR value hex",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": "invalid-hex",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "error converting PCR value to byte",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkExpectedPCRValues(tt.attestation, tt.policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicy(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
validPolicy := map[string]interface{}{
|
||||
"policy": map[string]interface{}{
|
||||
"product": map[string]interface{}{
|
||||
"name": "test-product",
|
||||
},
|
||||
},
|
||||
"rootOfTrust": map[string]interface{}{
|
||||
"productLine": "test-line",
|
||||
},
|
||||
"pcrConfig": map[string]interface{}{
|
||||
"pcrValues": map[string]interface{}{
|
||||
"sha256": map[string]string{
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyData, err := json.Marshal(validPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyPath := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyPath, validPolicyData, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policyPath string
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid policy file",
|
||||
policyPath: validPolicyPath,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Non-existent policy file",
|
||||
policyPath: "/nonexistent/path",
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyOpen,
|
||||
},
|
||||
{
|
||||
name: "Empty policy path",
|
||||
policyPath: "",
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyMissing,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := ReadPolicy(tt.policyPath, config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicyFromByte(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyData []byte
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid policy data",
|
||||
policyData: []byte(`{
|
||||
"policy": {
|
||||
"product": {
|
||||
"name": "test-product"
|
||||
}
|
||||
},
|
||||
"rootOfTrust": {
|
||||
"productLine": "test-line"
|
||||
},
|
||||
"pcrConfig": {
|
||||
"pcrValues": {
|
||||
"sha256": {
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
policyData: []byte(`{invalid json`),
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyDecode,
|
||||
},
|
||||
{
|
||||
name: "Empty policy data",
|
||||
policyData: []byte(``),
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyDecode,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := ReadPolicyFromByte(tt.policyData, config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPolicyToJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *attestation.Config
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
config: &attestation.Config{
|
||||
Config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "Milan",
|
||||
},
|
||||
},
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nil config",
|
||||
config: &attestation.Config{
|
||||
Config: nil,
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
expectError: false,
|
||||
expectedErr: ErrProtoMarshalFailed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonData, err := ConvertPolicyToJSON(tt.config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
assert.Nil(t, jsonData)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(jsonData, &result)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVTPMVerify(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
quote []byte
|
||||
teeNonce []byte
|
||||
vtpmNonce []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid quote data",
|
||||
quote: []byte("invalid"),
|
||||
teeNonce: []byte("tee-nonce"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty quote",
|
||||
quote: []byte{},
|
||||
teeNonce: []byte("tee-nonce"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VTPMVerify(tt.quote, tt.teeNonce, tt.vtpmNonce, writer, policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyQuote(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
quote []byte
|
||||
vtpmNonce []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid quote data",
|
||||
quote: []byte("invalid"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty quote",
|
||||
quote: []byte{},
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyQuote(tt.quote, tt.vtpmNonce, writer, policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterError(t *testing.T) {
|
||||
writer := &mockWriter{err: fmt.Errorf("write error")}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := VerifyQuote([]byte("invalid"), []byte("nonce"), writer, policy)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
@@ -33,7 +666,7 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
err = setAttestationPolicy(attestationPB, tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Change random data so in the signature so the signature failes
|
||||
// Change random data so in the signature so the signature fails
|
||||
attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cvm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
type TestServer struct {
|
||||
agent.UnimplementedAgentServiceServer
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
port int
|
||||
listenAddr string
|
||||
}
|
||||
|
||||
func NewTestServer() (*TestServer, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
|
||||
server := grpc.NewServer()
|
||||
healthServer := health.NewServer()
|
||||
|
||||
ts := &TestServer{
|
||||
server: server,
|
||||
health: healthServer,
|
||||
port: addr.Port,
|
||||
listenAddr: fmt.Sprintf("localhost:%d", addr.Port),
|
||||
}
|
||||
|
||||
svc := new(mocks.Service)
|
||||
agent.RegisterAgentServiceServer(server, agentgrpc.NewServer(svc))
|
||||
grpchealth.RegisterHealthServer(server, healthServer)
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
fmt.Printf("Server exited with error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
healthServer.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (s *TestServer) Stop() {
|
||||
if s.server != nil {
|
||||
s.server.GracefulStop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentClientIntegration(t *testing.T) {
|
||||
testServer, err := NewTestServer()
|
||||
require.NoError(t, err)
|
||||
defer testServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
serverRunning bool
|
||||
config pkggrpc.CVMClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "successful connection",
|
||||
serverRunning: true,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
},
|
||||
err: errors.New("failed to connect to grpc server"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !tt.serverRunning {
|
||||
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_NOT_SERVING)
|
||||
} else {
|
||||
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
|
||||
}
|
||||
|
||||
client, agentClient, err := NewCVMClient(tt.config)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
if err != nil {
|
||||
assert.Nil(t, client)
|
||||
assert.Nil(t, agentClient)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, client)
|
||||
require.NotNil(t, agentClient)
|
||||
defer client.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -529,6 +529,76 @@ func TestReceiveAttestation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiverIMAMeasurements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
totalSize int
|
||||
chunks [][]byte
|
||||
wantResult []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful single chunk receive",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 20,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: []byte("12345678912345678999"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "stream error",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 20,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("stream error"),
|
||||
},
|
||||
{
|
||||
name: "size mismatch",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 10,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("progress update exceeds total bytes: attempted to add 20 bytes, but only 10 bytes remain"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStream := new(mocks.AgentService_IMAMeasurementsClient[agent.IMAMeasurementsResponse])
|
||||
|
||||
p := New(true)
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
resultFile, err := os.CreateTemp("", "test_ima_measurements")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(resultFile.Name())
|
||||
})
|
||||
|
||||
if tt.wantErr != nil {
|
||||
mockStream.On("Recv").Return(nil, tt.wantErr).Once()
|
||||
}
|
||||
mockStream.On("Recv").Return(&agent.IMAMeasurementsResponse{Pcr10: []byte(tt.chunks[0]), File: []byte(tt.chunks[0])}, nil).Once()
|
||||
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||
|
||||
pcr10, err := p.ReceiveIMAMeasurements(tt.description, tt.totalSize, mockStream, resultFile)
|
||||
|
||||
assert.NoError(t, resultFile.Close())
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, pcr10)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockAlgoStream struct {
|
||||
stream agent.AgentService_AlgoClient
|
||||
sendCount int
|
||||
|
||||
+29
-29
@@ -176,6 +176,35 @@ func (sdk *agentSDK) AttestationResult(ctx context.Context, nonce [size32]byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
|
||||
request := &agent.IMAMeasurementsRequest{}
|
||||
|
||||
stream, err := sdk.client.IMAMeasurements(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
||||
var signature []byte
|
||||
var err error
|
||||
@@ -208,32 +237,3 @@ func generateMetadata(userID string, privateKey crypto.PrivateKey) (metadata.MD,
|
||||
kv[auth.SignatureMetadataKey] = base64.StdEncoding.EncodeToString(signature)
|
||||
return metadata.New(kv), nil
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
|
||||
request := &agent.IMAMeasurementsRequest{}
|
||||
|
||||
stream, err := sdk.client.IMAMeasurements(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
@@ -558,6 +558,82 @@ func TestAttestationResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIMAMeasurements(t *testing.T) {
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(client)
|
||||
|
||||
response := &agent.IMAMeasurementsResponse{
|
||||
File: []byte{
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
response *agent.IMAMeasurementsResponse
|
||||
svcRes []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "fetch IMA measurements successfully",
|
||||
response: response,
|
||||
svcRes: response.File,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "failed to fetch IMA measurements",
|
||||
response: &agent.IMAMeasurementsResponse{File: []byte{}},
|
||||
svcRes: nil,
|
||||
err: errors.New("failed to fetch IMA measurements"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("IMAMeasurements", mock.Anything).Return(tc.svcRes, tc.svcRes, tc.err)
|
||||
|
||||
file, err := os.CreateTemp("", "ima_measurements")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(file.Name())
|
||||
})
|
||||
|
||||
_, err = sdk.IMAMeasurements(context.Background(), file)
|
||||
|
||||
require.NoError(t, file.Close())
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("Expected gRPC status error, but got: %v", err)
|
||||
}
|
||||
|
||||
if tc.err != nil {
|
||||
if st.Message() != tc.err.Error() {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
res, err := os.ReadFile(file.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateKeys(t *testing.T, keyType string) (priv any, pub []byte) {
|
||||
switch keyType {
|
||||
case "ecdsa":
|
||||
|
||||
Reference in New Issue
Block a user