COCOS-460 - Restore test coverage to 65% (#465)
CI / ci (push) Has been cancelled

* Implement IMAMeasurements method in agentSDK and add corresponding unit tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for NewIMAMeasurements command in CLI

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add error assertion for command execution in NewIMAMeasurements test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix nil pointer dereference in Close method and update NewCreateVMCmd logic for manager client initialization

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor file permission settings to use octal notation and improve cleanup handling in NewCreateVMCmd test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive unit tests for state machine functionality

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add mock implementation for Algorithm interface and corresponding test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor file permission settings to use octal notation in TestStopComputationIntegration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove redundant reset test cases from TestStateMachine_Reset

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix race condition in action call verification in TestStateMachine_HandleEvent

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance state machine with reset functionality and improve thread safety in event handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Improve error handling in state machine start function during tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove concurrent reset and send event test from state machine tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove error logging for Start function in transition tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add mock implementations for AgentService_IMAMeasurementsClient and Service Shutdown method; enhance progress tests for IMA measurements handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for FileStorage functionality including loading, saving, and concurrent access

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance tests by adding dataset and algorithm hashes in handleRunReqChunks; improve error handling in TestFileStorage_ErrorHandling cleanup

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance TestManagerClient_Process by adding new test cases for Agent state and Disconnect requests; update setupMocks to include grpcClient

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix graceful shutdown in gRPC server by adding nil checks for health and server instances

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance TestAttestation by adding mock expectations for VTpmAttestation and Attestation methods; update service call to include platform parameter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance gRPC Server by adding synchronization for start/stop methods; prevent multiple starts and ensure graceful shutdown

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for gRPC server methods including VM creation, removal, and info retrieval

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add tests for SEVSNP and TDX host capabilities; remove unused vsock code

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add a newline for better readability in vm_test.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add integration tests for gRPC client in cvm_test.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove unused vsock dependencies and add comprehensive unit tests for GCP attestation functions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Skip GCP tests if credentials are not set

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add tests for error handling in attestation configuration and GCP commands

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Improve error handling in Azure VM test response writing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Skip tests in GCP functions if credentials are not set

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive unit tests for Azure attestation provider and verifier

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for TPM functionality and improve error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for attestation functionality and improve error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add validation for teeNonce in TeeAttestation and implement comprehensive tests for provider methods

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor error messages in TDX attestation tests for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix error message in TeeAttestation test for valid nonce case

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add MeasurementProvider mock and update mockery configuration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add logging for product in parseUints and rename test functions for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor TestSevsnpverify to reset configuration and improve error logging

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-07-25 16:35:37 +03:00
committed by GitHub
parent 85a2b7a6c8
commit 4e8057f481
43 changed files with 9194 additions and 321 deletions
+354
View File
@@ -3,19 +3,33 @@
package atls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
certssdk "github.com/absmach/certs/sdk"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/encoding/protojson"
)
@@ -217,6 +231,346 @@ func TestGetPlatformTypeFromOID(t *testing.T) {
}
}
func TestVerifyCertificateExtension(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
teeNonce := append(pubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
cases := []struct {
name string
extension []byte
pubKey []byte
nonce []byte
platformType attestation.PlatformType
expectError bool
}{
{
name: "Valid extension with SNPvTPM",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "Invalid platform type",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: 999,
expectError: true,
},
{
name: "Empty extension",
extension: []byte{},
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "Empty public key",
extension: hashNonce[:],
pubKey: []byte{},
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "Empty nonce",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: []byte{},
platformType: attestation.SNPvTPM,
expectError: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := VerifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
if c.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetCertificateExtension(t *testing.T) {
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation-data"), nil)
pubKey := []byte("test-public-key")
nonce := make([]byte, 32)
_, err := rand.Read(nonce)
require.NoError(t, err)
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
assert.NoError(t, err)
assert.Equal(t, testOID, extension.Id)
assert.Equal(t, []byte("mock-attestation-data"), extension.Value)
}
func TestGetCertificateWithSelfSigned(t *testing.T) {
getCertFunc := GetCertificate("", "")
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{
ServerName: serverName,
}
cert, err := getCertFunc(clientHello)
if err != nil {
t.Logf("Expected error due to missing attestation setup: %v", err)
assert.Error(t, err)
} else {
assert.NotNil(t, cert)
assert.NotEmpty(t, cert.Certificate)
assert.NotNil(t, cert.PrivateKey)
}
}
func TestGetCertificateWithCA(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
mockCert := certssdk.Certificate{
Certificate: "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIBATANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1UZXN0IENBIFJvb3QwHhcNMjMwMzMxMDAwMDAwWhcNMjQwMzMxMDAwMDAwWjAYMRYwFAYDVQQDDA1UZXN0IENlcnRpZmljYXRlMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEtest-key-data-here\n-----END CERTIFICATE-----",
}
response, _ := json.Marshal(mockCert)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if _, err := w.Write(response); err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
}))
defer mockServer.Close()
getCertFunc := GetCertificate(mockServer.URL, "test-cvm-id")
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{
ServerName: serverName,
}
_, err = getCertFunc(clientHello)
if err != nil {
t.Logf("Expected error due to missing attestation setup: %v", err)
assert.Error(t, err)
}
}
func TestGetCertificateInvalidServerName(t *testing.T) {
getCertFunc := GetCertificate("", "")
cases := []struct {
name string
serverName string
expectErr string
}{
{
name: "Missing .nonce suffix",
serverName: "invalidname",
expectErr: "failed to get platform provider",
},
{
name: "Too short server name",
serverName: "short",
expectErr: "failed to get platform provider",
},
{
name: "Invalid nonce encoding",
serverName: "invalidhex.nonce",
expectErr: "failed to get platform provider",
},
{
name: "Wrong nonce length",
serverName: "deadbeef.nonce",
expectErr: "failed to get platform provider",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
clientHello := &tls.ClientHelloInfo{
ServerName: c.serverName,
}
cert, err := getCertFunc(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), c.expectErr)
assert.Nil(t, cert)
})
}
}
func TestProcessRequest(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/success":
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(`{"message": "success"}`)); err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
case "/notfound":
w.WriteHeader(http.StatusNotFound)
if _, err := w.Write([]byte(`{"error": "not found"}`)); err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
case "/headers":
if r.Header.Get("X-Custom-Header") == "test-value" {
w.Header().Set("X-Response-Header", "received")
}
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(`{"headers": "ok"}`)); err != nil {
http.Error(w, "Failed to write response", http.StatusInternalServerError)
return
}
default:
w.WriteHeader(http.StatusInternalServerError)
}
}))
defer testServer.Close()
cases := []struct {
name string
method string
url string
data []byte
headers map[string]string
expectedRespCodes []int
expectError bool
}{
{
name: "Successful GET request",
method: http.MethodGet,
url: testServer.URL + "/success",
data: nil,
headers: nil,
expectedRespCodes: []int{http.StatusOK},
expectError: false,
},
{
name: "Successful POST request with data",
method: http.MethodPost,
url: testServer.URL + "/success",
data: []byte(`{"test": "data"}`),
headers: nil,
expectedRespCodes: []int{http.StatusOK},
expectError: false,
},
{
name: "Request with custom headers",
method: http.MethodGet,
url: testServer.URL + "/headers",
data: nil,
headers: map[string]string{"X-Custom-Header": "test-value"},
expectedRespCodes: []int{http.StatusOK},
expectError: false,
},
{
name: "Request with unexpected status code",
method: http.MethodGet,
url: testServer.URL + "/notfound",
data: nil,
headers: nil,
expectedRespCodes: []int{http.StatusOK},
expectError: true,
},
{
name: "Request with multiple expected status codes",
method: http.MethodGet,
url: testServer.URL + "/notfound",
data: nil,
headers: nil,
expectedRespCodes: []int{http.StatusOK, http.StatusNotFound},
expectError: false,
},
{
name: "Request to invalid URL",
method: http.MethodGet,
url: "invalid-url",
data: nil,
headers: nil,
expectedRespCodes: []int{http.StatusOK},
expectError: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
headers, body, err := processRequest(c.method, c.url, c.data, c.headers, c.expectedRespCodes...)
if c.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, headers)
assert.NotNil(t, body)
if c.name == "Request with custom headers" {
assert.Equal(t, "received", headers.Get("X-Response-Header"))
}
}
})
}
}
func TestGetCertificateExtensionError(t *testing.T) {
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("failed to get attestation"))
pubKey := []byte("test-public-key")
nonce := make([]byte, 32)
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get attestation")
assert.Equal(t, pkix.Extension{}, extension)
}
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
file, err := os.ReadFile("../../attestation.bin")
require.NoError(t, err)
+213
View File
@@ -0,0 +1,213 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCCPlatform(t *testing.T) {
tests := []struct {
name string
sevSnpGuestExists bool
sevSnpGuestvTPMExists bool
tdxGuestExists bool
isAzure bool
expected PlatformType
}{
{
name: "No CC platform detected",
sevSnpGuestExists: false,
sevSnpGuestvTPMExists: false,
tdxGuestExists: false,
isAzure: false,
expected: NoCC,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CCPlatform()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSevSnpGuestDeviceExists(t *testing.T) {
tests := []struct {
name string
openDeviceErr error
expected bool
}{
{
name: "device does not exist or fails to open",
openDeviceErr: fmt.Errorf("device not found"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SevSnpGuestDeviceExists()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSevSnpGuestvTPMExists(t *testing.T) {
tests := []struct {
name string
vTPMExists bool
sevSnpExists bool
expected bool
}{
{
name: "vTPM exists but SEV-SNP does not",
vTPMExists: true,
sevSnpExists: false,
expected: false,
},
{
name: "SEV-SNP exists but vTPM does not",
vTPMExists: false,
sevSnpExists: true,
expected: false,
},
{
name: "neither exists",
vTPMExists: false,
sevSnpExists: false,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SevSnpGuestvTPMExists()
assert.Equal(t, tt.expected, result)
})
}
}
func TestVTPMExists(t *testing.T) {
tests := []struct {
name string
openTPMErr error
expected bool
}{
{
name: "TPM fails to open",
openTPMErr: fmt.Errorf("TPM not found"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := vTPMExists()
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsAzureVM(t *testing.T) {
tests := []struct {
name string
vTPMExists bool
statusCode int
responseBody string
httpError error
expected bool
}{
{
name: "Azure VM with empty response body",
vTPMExists: true,
statusCode: http.StatusOK,
responseBody: "",
httpError: nil,
expected: false,
},
{
name: "Azure VM with non-200 status code",
vTPMExists: true,
statusCode: http.StatusNotFound,
responseBody: "",
httpError: nil,
expected: false,
},
{
name: "HTTP request error",
vTPMExists: true,
statusCode: 0,
responseBody: "",
httpError: fmt.Errorf("connection failed"),
expected: false,
},
{
name: "vTPM does not exist",
vTPMExists: false,
statusCode: http.StatusOK,
responseBody: `{"compute":{"name":"test-vm"}}`,
httpError: nil,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, "true", r.Header.Get("Metadata"))
expectedURL := fmt.Sprintf("/?api-version=%s", azureApiVersion)
assert.Equal(t, expectedURL, r.URL.String())
if tt.httpError != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(tt.statusCode)
if tt.responseBody != "" {
if _, err := w.Write([]byte(tt.responseBody)); err != nil {
t.Fatalf("Failed to write response body: %v", err)
}
}
}))
defer server.Close()
if tt.httpError != nil {
server.Close()
}
result := isAzureVM()
assert.Equal(t, tt.expected, result)
})
}
}
func TestTDXGuestDeviceExists(t *testing.T) {
tests := []struct {
name string
openDeviceErr error
expected bool
}{
{
name: "TDX device does not exist or fails to open",
openDeviceErr: fmt.Errorf("device not found"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := TDXGuestDeviceExists()
assert.Equal(t, tt.expected, result)
})
}
}
+578
View File
@@ -0,0 +1,578 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package azure
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-tpm-tools/proto/attest"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/attestation"
"google.golang.org/protobuf/proto"
)
var (
testNonce = []byte("test-nonce-12345678901234567890123456789012")
testReport = []byte("test-report-data")
)
func TestNewProvider(t *testing.T) {
tests := []struct {
name string
want attestation.Provider
}{
{
name: "creates new provider successfully",
want: provider{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewProvider()
assert.Equal(t, tt.want, got)
})
}
}
func TestProvider_Attestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
vTpmNonce []byte
wantErr bool
errorMessage string
}{
{
name: "maa parameters error",
teeNonce: testNonce,
vTpmNonce: testNonce,
wantErr: true,
errorMessage: "failed to get report",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewProvider()
result, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestProvider_TeeAttestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
wantErr bool
errorMessage string
}{
{
name: "maa parameters error",
teeNonce: testNonce,
wantErr: true,
errorMessage: "failed to get report",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewProvider()
result, err := p.TeeAttestation(tt.teeNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestProvider_AzureAttestationToken(t *testing.T) {
tests := []struct {
name string
tokenNonce []byte
setupServer func() *httptest.Server
wantErr bool
errorMessage string
}{
{
name: "server error",
tokenNonce: testNonce,
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
},
wantErr: true,
errorMessage: "failed to fetch Azure token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := tt.setupServer()
defer server.Close()
originalURL := MaaURL
MaaURL = server.URL
defer func() { MaaURL = originalURL }()
p := NewProvider()
result, err := p.AzureAttestationToken(tt.tokenNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestNewVerifier(t *testing.T) {
tests := []struct {
name string
writer io.Writer
}{
{
name: "creates verifier with buffer writer",
writer: &bytes.Buffer{},
},
{
name: "creates verifier with nil writer",
writer: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := NewVerifier(tt.writer)
verifier, ok := v.(verifier)
assert.True(t, ok)
assert.Equal(t, tt.writer, verifier.writer)
assert.NotNil(t, verifier.Policy)
assert.NotNil(t, verifier.Policy.Config)
assert.NotNil(t, verifier.Policy.PcrConfig)
})
}
}
func TestNewVerifierWithPolicy(t *testing.T) {
tests := []struct {
name string
writer io.Writer
policy *attestation.Config
}{
{
name: "creates verifier with custom policy",
writer: &bytes.Buffer{},
policy: &attestation.Config{
Config: &check.Config{
Policy: &check.Policy{},
RootOfTrust: &check.RootOfTrust{},
},
PcrConfig: &attestation.PcrConfig{},
},
},
{
name: "creates verifier with nil policy",
writer: &bytes.Buffer{},
policy: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := NewVerifierWithPolicy(tt.writer, tt.policy)
verifier, ok := v.(verifier)
assert.True(t, ok)
assert.Equal(t, tt.writer, verifier.writer)
assert.NotNil(t, verifier.Policy)
})
}
}
func TestVerifier_VerifTeeAttestation(t *testing.T) {
tests := []struct {
name string
report []byte
teeNonce []byte
wantErr bool
errorMessage string
}{
{
name: "empty report",
report: []byte{},
teeNonce: testNonce,
wantErr: true,
},
{
name: "invalid report format",
report: []byte("invalid-report"),
teeNonce: testNonce,
wantErr: true,
},
{
name: "nil nonce",
report: testReport,
teeNonce: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := NewVerifier(&bytes.Buffer{})
err := v.VerifTeeAttestation(tt.report, tt.teeNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestVerifier_VerifyAttestation(t *testing.T) {
validQuote := &attest.Attestation{
TeeAttestation: &attest.Attestation_SevSnpAttestation{
SevSnpAttestation: &sevsnp.Attestation{
Report: &sevsnp.Report{
HostData: []byte("test-data"),
},
Product: &sevsnp.SevProduct{
Name: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
},
CertificateChain: &sevsnp.CertificateChain{
Extras: make(map[string][]byte),
},
},
},
}
validReport, _ := proto.Marshal(validQuote)
tests := []struct {
name string
report []byte
teeNonce []byte
vTpmNonce []byte
wantErr bool
errorMessage string
}{
{
name: "successful verification",
report: validReport,
teeNonce: testNonce,
vTpmNonce: testNonce,
wantErr: true,
errorMessage: "failed to verify vTPM attestation report",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := NewVerifier(&bytes.Buffer{})
err := v.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestFetchAzureAttestationToken(t *testing.T) {
tests := []struct {
name string
tokenNonce []byte
maaURL string
setupServer func() *httptest.Server
wantErr bool
errorMessage string
}{
{
name: "server error",
tokenNonce: testNonce,
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
},
wantErr: true,
errorMessage: "error fetching azure token",
},
{
name: "invalid url",
tokenNonce: testNonce,
setupServer: func() *httptest.Server {
return nil
},
wantErr: true,
errorMessage: "error fetching azure token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var url string
if tt.setupServer != nil {
server := tt.setupServer()
if server != nil {
defer server.Close()
url = server.URL
}
}
if tt.name == "invalid url" {
url = "invalid-url"
}
result, err := FetchAzureAttestationToken(tt.tokenNonce, url)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestValidateToken(t *testing.T) {
tests := []struct {
name string
token string
setupServer func() *httptest.Server
wantErr bool
errorMessage string
}{
{
name: "invalid token format",
token: "invalid-token",
setupServer: nil,
wantErr: true,
errorMessage: "failed to parse token",
},
{
name: "empty token",
token: "",
setupServer: nil,
wantErr: true,
errorMessage: "failed to parse token",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setupServer != nil {
server := tt.setupServer()
defer server.Close()
originalURL := MaaURL
MaaURL = server.URL
defer func() { MaaURL = originalURL }()
}
result, err := validateToken(tt.token)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMessage != "" {
assert.Contains(t, err.Error(), tt.errorMessage)
}
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestIntegration_FullAttestationFlow(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
t.Run("full attestation flow with mock server", func(t *testing.T) {
maaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/attest":
response := map[string]interface{}{
"token": createMockJWT(),
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
t.Fatalf("Failed to encode response: %v", err)
}
case "/.well-known/openid_configuration":
config := map[string]interface{}{
"jwks_uri": "maaServer.URL" + "/certs",
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(config); err != nil {
t.Fatalf("Failed to encode OpenID configuration: %v", err)
}
case "/certs":
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kid": "test-kid",
"kty": "RSA",
"use": "sig",
"n": "test-n-value",
"e": "AQAB",
},
},
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(jwks); err != nil {
t.Fatalf("Failed to encode JWKS: %v", err)
}
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer maaServer.Close()
originalURL := MaaURL
MaaURL = maaServer.URL
defer func() { MaaURL = originalURL }()
provider := NewProvider()
verifier := NewVerifier(&bytes.Buffer{})
teeNonce := []byte("test-tee-nonce-1234567890123456789012")
vtpmNonce := []byte("test-vtpm-nonce-123456789012345678901")
teeReport, err := provider.TeeAttestation(teeNonce)
if err != nil {
t.Logf("TEE attestation failed (expected in mock environment): %v", err)
}
vtpmReport, err := provider.VTpmAttestation(vtpmNonce)
if err != nil {
t.Logf("vTPM attestation failed (expected in mock environment): %v", err)
}
token, err := provider.AzureAttestationToken(teeNonce)
if err != nil {
t.Logf("Azure attestation token failed (expected in mock environment): %v", err)
}
assert.NotNil(t, provider)
assert.NotNil(t, verifier)
t.Logf("TEE report length: %d", len(teeReport))
t.Logf("vTPM report length: %d", len(vtpmReport))
t.Logf("Token length: %d", len(token))
})
}
func TestIntegration_ErrorPropagation(t *testing.T) {
t.Run("error propagation through full stack", func(t *testing.T) {
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
if _, err := w.Write([]byte("Internal Server Error")); err != nil {
t.Fatalf("Failed to write response: %v", err)
}
}))
defer failingServer.Close()
originalURL := MaaURL
MaaURL = failingServer.URL
defer func() { MaaURL = originalURL }()
provider := NewProvider()
_, err := provider.AzureAttestationToken([]byte("test-nonce"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch Azure token")
_, err = GenerateAttestationPolicy("invalid-token", "test-product", 1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to validate token")
})
}
func createMockJWT() string {
claims := jwt.MapClaims{
"iss": "https://test-issuer.com",
"aud": "test-audience",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"x-ms-isolation-tee": map[string]interface{}{
"x-ms-sevsnpvm-familyId": "1234567890abcdef",
"x-ms-sevsnpvm-imageId": "fedcba0987654321",
"x-ms-sevsnpvm-launchmeasurement": "abcdef1234567890",
"x-ms-sevsnpvm-bootloader-svn": float64(1),
"x-ms-sevsnpvm-tee-svn": float64(2),
"x-ms-sevsnpvm-snpfw-svn": float64(3),
"x-ms-sevsnpvm-microcode-svn": float64(4),
"x-ms-sevsnpvm-guestsvn": float64(5),
"x-ms-sevsnpvm-idkeydigest": "1234567890abcdef",
"x-ms-sevsnpvm-reportid": "fedcba0987654321",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["jku"] = "https://test-url.com"
token.Header["kid"] = "test-kid"
// Return unsigned token for testing
return token.Raw
}
@@ -0,0 +1,138 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mocks
import mock "github.com/stretchr/testify/mock"
// MeasurementProvider is an autogenerated mock type for the MeasurementProvider type
type MeasurementProvider struct {
mock.Mock
}
type MeasurementProvider_Expecter struct {
mock *mock.Mock
}
func (_m *MeasurementProvider) EXPECT() *MeasurementProvider_Expecter {
return &MeasurementProvider_Expecter{mock: &_m.Mock}
}
// Run provides a mock function with given fields: binaryPath
func (_m *MeasurementProvider) Run(binaryPath string) ([]byte, error) {
ret := _m.Called(binaryPath)
if len(ret) == 0 {
panic("no return value specified for Run")
}
var r0 []byte
var r1 error
if rf, ok := ret.Get(0).(func(string) ([]byte, error)); ok {
return rf(binaryPath)
}
if rf, ok := ret.Get(0).(func(string) []byte); ok {
r0 = rf(binaryPath)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(binaryPath)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MeasurementProvider_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
type MeasurementProvider_Run_Call struct {
*mock.Call
}
// Run is a helper method to define mock.On call
// - binaryPath string
func (_e *MeasurementProvider_Expecter) Run(binaryPath interface{}) *MeasurementProvider_Run_Call {
return &MeasurementProvider_Run_Call{Call: _e.mock.On("Run", binaryPath)}
}
func (_c *MeasurementProvider_Run_Call) Run(run func(binaryPath string)) *MeasurementProvider_Run_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MeasurementProvider_Run_Call) Return(_a0 []byte, _a1 error) *MeasurementProvider_Run_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MeasurementProvider_Run_Call) RunAndReturn(run func(string) ([]byte, error)) *MeasurementProvider_Run_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function with no fields
func (_m *MeasurementProvider) Stop() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MeasurementProvider_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type MeasurementProvider_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *MeasurementProvider_Expecter) Stop() *MeasurementProvider_Stop_Call {
return &MeasurementProvider_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *MeasurementProvider_Stop_Call) Run(run func()) *MeasurementProvider_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MeasurementProvider_Stop_Call) Return(_a0 error) *MeasurementProvider_Stop_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MeasurementProvider_Stop_Call) RunAndReturn(run func() error) *MeasurementProvider_Stop_Call {
_c.Call.Return(run)
return _c
}
// NewMeasurementProvider creates a new instance of MeasurementProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMeasurementProvider(t interface {
mock.TestingT
Cleanup(func())
}) *MeasurementProvider {
mock := &MeasurementProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+261
View File
@@ -0,0 +1,261 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package gcp
import (
"context"
"errors"
"testing"
"cloud.google.com/go/storage"
"github.com/google/gce-tcb-verifier/proto/endorsement"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestExtract384BitMeasurement(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
setupMock func()
expected string
expectError bool
errorMsg string
}{
{
name: "nil attestation",
attestation: nil,
expectError: true,
errorMsg: "report is nil",
},
{
name: "short report",
attestation: &sevsnp.Attestation{Report: &sevsnp.Report{}},
expectError: true,
errorMsg: "failed to transform report to binary",
},
{
name: "empty report",
attestation: &sevsnp.Attestation{},
expectError: true,
errorMsg: "failed to transform report to binary",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := Extract384BitMeasurement(tt.attestation)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
assert.Empty(t, result)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestGetLaunchEndorsement(t *testing.T) {
tests := []struct {
name string
measurement384 string
setupMock func() ([]byte, error)
expectError bool
errorMsg string
}{
{
name: "successful retrieval",
measurement384: "test-measurement",
setupMock: func() ([]byte, error) {
goldenUEFI := &endorsement.VMGoldenMeasurement{
SevSnp: &endorsement.VMSevSnp{
Policy: 12345,
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
},
}
goldenBytes, _ := proto.Marshal(goldenUEFI)
launchEndorsement := &endorsement.VMLaunchEndorsement{
SerializedUefiGolden: goldenBytes,
}
return proto.Marshal(launchEndorsement)
},
expectError: false,
},
{
name: "storage client error",
measurement384: "test-measurement",
setupMock: func() ([]byte, error) {
return nil, errors.New("storage client error")
},
expectError: true,
errorMsg: "failed to create reader",
},
{
name: "object not found",
measurement384: "non-existent-measurement",
setupMock: func() ([]byte, error) {
return nil, storage.ErrObjectNotExist
},
expectError: true,
errorMsg: "failed to create reader",
},
{
name: "invalid protobuf data",
measurement384: "test-measurement",
setupMock: func() ([]byte, error) {
return []byte("invalid protobuf data"), nil
},
expectError: true,
errorMsg: "failed to create reader",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
// skip if credentials are not set
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
t.Skip("Skipping test due to missing GCP credentials")
}
_, err := GetLaunchEndorsement(ctx, tt.measurement384)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
}
})
}
}
func TestGenerateAttestationPolicy(t *testing.T) {
tests := []struct {
name string
endorsement *endorsement.VMGoldenMeasurement
vcpuNum uint32
expectError bool
errorMsg string
}{
{
name: "valid endorsement",
endorsement: &endorsement.VMGoldenMeasurement{
SevSnp: &endorsement.VMSevSnp{
Policy: 12345,
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
},
},
vcpuNum: 1,
expectError: false,
},
{
name: "missing measurement for vcpu",
endorsement: &endorsement.VMGoldenMeasurement{
SevSnp: &endorsement.VMSevSnp{
Policy: 12345,
Measurements: map[uint32][]byte{2: []byte("test-measurement")},
},
},
vcpuNum: 1,
expectError: false,
},
{
name: "empty measurements map",
endorsement: &endorsement.VMGoldenMeasurement{
SevSnp: &endorsement.VMSevSnp{
Policy: 12345,
Measurements: map[uint32][]byte{},
},
},
vcpuNum: 1,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := GenerateAttestationPolicy(tt.endorsement, tt.vcpuNum)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
assert.NotNil(t, result.Config)
assert.NotNil(t, result.Config.Policy)
assert.NotNil(t, result.Config.RootOfTrust)
assert.NotNil(t, result.PcrConfig)
assert.Equal(t, tt.endorsement.SevSnp.Policy, result.Config.Policy.Policy)
assert.Equal(t, tt.endorsement.SevSnp.Measurements[tt.vcpuNum], result.Config.Policy.Measurement)
assert.False(t, result.Config.RootOfTrust.DisallowNetwork)
assert.True(t, result.Config.RootOfTrust.CheckCrl)
assert.Equal(t, "Milan", result.Config.RootOfTrust.Product)
assert.Equal(t, "Milan", result.Config.RootOfTrust.ProductLine)
}
})
}
}
func TestDownloadOvmfFile(t *testing.T) {
tests := []struct {
name string
digest string
expectError bool
errorMsg string
}{
{
name: "successful download",
digest: "test-digest",
expectError: false,
},
{
name: "storage client error",
digest: "test-digest",
expectError: true,
errorMsg: "failed to create reader",
},
{
name: "object not found",
digest: "non-existent-digest",
expectError: true,
errorMsg: "failed to create reader",
},
{
name: "read error",
digest: "test-digest",
expectError: true,
errorMsg: "failed to create reader",
},
{
name: "empty digest",
digest: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
// skip if credentials are not set
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
t.Skip("Skipping test due to missing GCP credentials")
}
_, err := DownloadOvmfFile(ctx, tt.digest)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
}
})
}
}
-142
View File
@@ -253,148 +253,6 @@ func (_c *Provider_VTpmAttestation_Call) RunAndReturn(run func([]byte) ([]byte,
return _c
}
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
func (_m *Provider) VerifTeeAttestation(report []byte, teeNonce []byte) error {
ret := _m.Called(report, teeNonce)
if len(ret) == 0 {
panic("no return value specified for VerifTeeAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
r0 = rf(report, teeNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Provider_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
type Provider_VerifTeeAttestation_Call struct {
*mock.Call
}
// VerifTeeAttestation is a helper method to define mock.On call
// - report []byte
// - teeNonce []byte
func (_e *Provider_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Provider_VerifTeeAttestation_Call {
return &Provider_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
}
func (_c *Provider_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Provider_VerifTeeAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte))
})
return _c
}
func (_c *Provider_VerifTeeAttestation_Call) Return(_a0 error) *Provider_VerifTeeAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Provider_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifTeeAttestation_Call {
_c.Call.Return(run)
return _c
}
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
func (_m *Provider) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
ret := _m.Called(report, vTpmNonce)
if len(ret) == 0 {
panic("no return value specified for VerifVTpmAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
r0 = rf(report, vTpmNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Provider_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
type Provider_VerifVTpmAttestation_Call struct {
*mock.Call
}
// VerifVTpmAttestation is a helper method to define mock.On call
// - report []byte
// - vTpmNonce []byte
func (_e *Provider_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Provider_VerifVTpmAttestation_Call {
return &Provider_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
}
func (_c *Provider_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Provider_VerifVTpmAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte))
})
return _c
}
func (_c *Provider_VerifVTpmAttestation_Call) Return(_a0 error) *Provider_VerifVTpmAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Provider_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifVTpmAttestation_Call {
_c.Call.Return(run)
return _c
}
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
func (_m *Provider) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
ret := _m.Called(report, teeNonce, vTpmNonce)
if len(ret) == 0 {
panic("no return value specified for VerifyAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
r0 = rf(report, teeNonce, vTpmNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Provider_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
type Provider_VerifyAttestation_Call struct {
*mock.Call
}
// VerifyAttestation is a helper method to define mock.On call
// - report []byte
// - teeNonce []byte
// - vTpmNonce []byte
func (_e *Provider_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Provider_VerifyAttestation_Call {
return &Provider_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
}
func (_c *Provider_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Provider_VerifyAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
})
return _c
}
func (_c *Provider_VerifyAttestation_Call) Return(_a0 error) *Provider_VerifyAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Provider_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Provider_VerifyAttestation_Call {
_c.Call.Return(run)
return _c
}
// NewProvider creates a new instance of Provider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewProvider(t interface {
+223
View File
@@ -0,0 +1,223 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mocks
import mock "github.com/stretchr/testify/mock"
// Verifier is an autogenerated mock type for the Verifier type
type Verifier struct {
mock.Mock
}
type Verifier_Expecter struct {
mock *mock.Mock
}
func (_m *Verifier) EXPECT() *Verifier_Expecter {
return &Verifier_Expecter{mock: &_m.Mock}
}
// JSONToPolicy provides a mock function with given fields: path
func (_m *Verifier) JSONToPolicy(path string) error {
ret := _m.Called(path)
if len(ret) == 0 {
panic("no return value specified for JSONToPolicy")
}
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(path)
} else {
r0 = ret.Error(0)
}
return r0
}
// Verifier_JSONToPolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JSONToPolicy'
type Verifier_JSONToPolicy_Call struct {
*mock.Call
}
// JSONToPolicy is a helper method to define mock.On call
// - path string
func (_e *Verifier_Expecter) JSONToPolicy(path interface{}) *Verifier_JSONToPolicy_Call {
return &Verifier_JSONToPolicy_Call{Call: _e.mock.On("JSONToPolicy", path)}
}
func (_c *Verifier_JSONToPolicy_Call) Run(run func(path string)) *Verifier_JSONToPolicy_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *Verifier_JSONToPolicy_Call) Return(_a0 error) *Verifier_JSONToPolicy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Verifier_JSONToPolicy_Call) RunAndReturn(run func(string) error) *Verifier_JSONToPolicy_Call {
_c.Call.Return(run)
return _c
}
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
func (_m *Verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
ret := _m.Called(report, teeNonce)
if len(ret) == 0 {
panic("no return value specified for VerifTeeAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
r0 = rf(report, teeNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Verifier_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
type Verifier_VerifTeeAttestation_Call struct {
*mock.Call
}
// VerifTeeAttestation is a helper method to define mock.On call
// - report []byte
// - teeNonce []byte
func (_e *Verifier_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Verifier_VerifTeeAttestation_Call {
return &Verifier_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
}
func (_c *Verifier_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Verifier_VerifTeeAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte))
})
return _c
}
func (_c *Verifier_VerifTeeAttestation_Call) Return(_a0 error) *Verifier_VerifTeeAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Verifier_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifTeeAttestation_Call {
_c.Call.Return(run)
return _c
}
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
func (_m *Verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
ret := _m.Called(report, vTpmNonce)
if len(ret) == 0 {
panic("no return value specified for VerifVTpmAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
r0 = rf(report, vTpmNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Verifier_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
type Verifier_VerifVTpmAttestation_Call struct {
*mock.Call
}
// VerifVTpmAttestation is a helper method to define mock.On call
// - report []byte
// - vTpmNonce []byte
func (_e *Verifier_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Verifier_VerifVTpmAttestation_Call {
return &Verifier_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
}
func (_c *Verifier_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Verifier_VerifVTpmAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte))
})
return _c
}
func (_c *Verifier_VerifVTpmAttestation_Call) Return(_a0 error) *Verifier_VerifVTpmAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Verifier_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifVTpmAttestation_Call {
_c.Call.Return(run)
return _c
}
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
func (_m *Verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
ret := _m.Called(report, teeNonce, vTpmNonce)
if len(ret) == 0 {
panic("no return value specified for VerifyAttestation")
}
var r0 error
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
r0 = rf(report, teeNonce, vTpmNonce)
} else {
r0 = ret.Error(0)
}
return r0
}
// Verifier_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
type Verifier_VerifyAttestation_Call struct {
*mock.Call
}
// VerifyAttestation is a helper method to define mock.On call
// - report []byte
// - teeNonce []byte
// - vTpmNonce []byte
func (_e *Verifier_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Verifier_VerifyAttestation_Call {
return &Verifier_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
}
func (_c *Verifier_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Verifier_VerifyAttestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
})
return _c
}
func (_c *Verifier_VerifyAttestation_Call) Return(_a0 error) *Verifier_VerifyAttestation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Verifier_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Verifier_VerifyAttestation_Call {
_c.Call.Return(run)
return _c
}
// NewVerifier creates a new instance of Verifier. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewVerifier(t interface {
mock.TestingT
Cleanup(func())
}) *Verifier {
mock := &Verifier{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+326 -17
View File
@@ -7,11 +7,10 @@
package quoteprovider
import (
"fmt"
"os"
"path"
"testing"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
@@ -19,51 +18,361 @@ import (
)
func TestFillInAttestationLocal(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
os.Setenv("HOME", originalHome)
}()
tempDir, err := os.MkdirTemp("", "test_home")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
cocosDir := tempDir + "/.cocos/Milan"
os.Setenv("HOME", tempDir)
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
err = os.MkdirAll(cocosDir, 0o755)
require.NoError(t, err)
bundleContent := []byte("mock ASK ARK bundle")
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
bundlePath := path.Join(cocosDir, caBundleName)
err = os.WriteFile(bundlePath, bundleContent, 0o644)
require.NoError(t, err)
config := check.Config{
RootOfTrust: &check.RootOfTrust{},
Policy: &check.Policy{},
config := &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
}
tests := []struct {
name string
attestation *sevsnp.Attestation
err error
name string
attestation *sevsnp.Attestation
setupFunc func()
expectedError bool
errorContains string
}{
{
name: "Empty attestation",
name: "Empty attestation - creates new chain",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
CertificateChain: nil,
},
err: nil,
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "Attestation with existing chain",
name: "Attestation with existing chain - no changes needed",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("existing ASK cert"),
ArkCert: []byte("existing ARK cert"),
},
},
err: nil,
setupFunc: func() {},
expectedError: false,
},
{
name: "Attestation with empty chain - tries to load from file",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "No bundle file exists - no error",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {
os.Remove(bundlePath)
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := fillInAttestationLocal(tt.attestation, &config)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
os.Setenv("HOME", tempDir)
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
t.Fatalf("Failed to write bundle file: %v", err)
}
}
tt.setupFunc()
err := fillInAttestationLocal(tt.attestation, config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetProductName(t *testing.T) {
tests := []struct {
name string
product string
expected sevsnp.SevProduct_SevProductName
}{
{
name: "Milan product",
product: sevSnpProductMilan,
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
{
name: "Genoa product",
product: sevSnpProductGenoa,
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
},
{
name: "Unknown product",
product: "UnknownProduct",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Empty product",
product: "",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Case sensitive - milan lowercase",
product: "milan",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetProductName(tt.product)
assert.Equal(t, tt.expected, result)
})
}
}
func TestVerifyReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Invalid product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: "InvalidProduct",
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "product name must be",
},
{
name: "Valid Milan product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Valid Genoa product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductGenoa,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Config with existing product policy",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{
Product: &sevsnp.SevProduct{
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
},
},
expectedError: true,
errorContains: "attestation verification failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := verifyReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Basic validation test",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
{
name: "Validation with report data",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
ReportData: []byte("test report datatest report datatest report datatest report data"),
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestFetchAttestation(t *testing.T) {
tests := []struct {
name string
reportData []byte
vmpl uint
expectedError bool
errorContains string
}{
{
name: "Report data too large",
reportData: make([]byte, Nonce+1),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Valid report data size",
reportData: make([]byte, 32),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Maximum valid report data size",
reportData: make([]byte, Nonce),
vmpl: 1,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Empty report data",
reportData: []byte{},
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := FetchAttestation(tt.reportData, tt.vmpl)
if tt.expectedError {
assert.Error(t, err)
assert.Empty(t, result)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
assert.NotEmpty(t, result)
}
})
}
}
func TestGetLeveledQuoteProvider(t *testing.T) {
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
provider, err := GetLeveledQuoteProvider()
if err != nil {
assert.Error(t, err)
assert.Nil(t, provider)
} else {
assert.NoError(t, err)
assert.NotNil(t, provider)
}
})
}
+8
View File
@@ -45,6 +45,14 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
}
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
if teeNonce == nil {
return nil, errors.New("tee nonce is required for TDX attestation")
}
if len(teeNonce) != 64 {
return nil, fmt.Errorf("invalid tee nonce length: expected 64 bytes, got %d bytes", len(teeNonce))
}
quoteprovider, err := client.GetQuoteProvider()
if err != nil {
return nil, errors.Wrap(err, errOpenTDXDevice)
+622
View File
@@ -0,0 +1,622 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package tdx
import (
"os"
"path/filepath"
"testing"
"github.com/google/go-tdx-guest/proto/checkconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"google.golang.org/protobuf/encoding/protojson"
)
func TestNewProvider(t *testing.T) {
tests := []struct {
name string
want attestation.Provider
}{
{
name: "should create new provider successfully",
want: provider{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewProvider()
assert.IsType(t, tt.want, got)
})
}
}
func TestProvider_Attestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
vTpmNonce []byte
wantErr bool
errContains string
}{
{
name: "should handle empty nonces",
teeNonce: []byte{},
vTpmNonce: []byte{},
wantErr: true,
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
},
{
name: "should handle valid nonces",
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
vTpmNonce: []byte("vtpm-nonce"),
wantErr: true,
errContains: "/sys/kernel/config/tsm/report",
},
{
name: "should handle nil nonces",
teeNonce: nil,
vTpmNonce: nil,
wantErr: true,
errContains: "tee nonce is required for TDX attestation",
},
{
name: "should handle large nonce",
teeNonce: make([]byte, 64),
vTpmNonce: make([]byte, 32),
wantErr: true,
errContains: "/sys/kernel/config/tsm/report",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider{}
got, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
assert.Nil(t, got)
} else {
assert.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestProvider_TeeAttestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
wantErr bool
errContains string
}{
{
name: "should handle empty nonce",
teeNonce: []byte{},
wantErr: true,
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
},
{
name: "should handle valid nonce",
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
wantErr: true,
errContains: "/sys/kernel/config/tsm/report",
},
{
name: "should handle nil nonce",
teeNonce: nil,
wantErr: true,
errContains: "tee nonce is required for TDX attestation",
},
{
name: "should handle 64-byte nonce",
teeNonce: make([]byte, 64),
wantErr: true,
errContains: "/sys/kernel/config/tsm/report",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider{}
got, err := p.TeeAttestation(tt.teeNonce)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
assert.Nil(t, got)
} else {
assert.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestProvider_VTpmAttestation(t *testing.T) {
tests := []struct {
name string
vTpmNonce []byte
wantErr bool
errContains string
}{
{
name: "should return error for empty nonce",
vTpmNonce: []byte{},
wantErr: true,
errContains: "vTPM attestation fetch is not supported",
},
{
name: "should return error for valid nonce",
vTpmNonce: []byte("vtpm-nonce"),
wantErr: true,
errContains: "vTPM attestation fetch is not supported",
},
{
name: "should return error for nil nonce",
vTpmNonce: nil,
wantErr: true,
errContains: "vTPM attestation fetch is not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider{}
got, err := p.VTpmAttestation(tt.vTpmNonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
assert.Nil(t, got)
})
}
}
func TestProvider_AzureAttestationToken(t *testing.T) {
tests := []struct {
name string
tokenNonce []byte
wantErr bool
errContains string
}{
{
name: "should return error for empty nonce",
tokenNonce: []byte{},
wantErr: true,
errContains: "Azure attestation token is not supported",
},
{
name: "should return error for valid nonce",
tokenNonce: []byte("token-nonce"),
wantErr: true,
errContains: "Azure attestation token is not supported",
},
{
name: "should return error for nil nonce",
tokenNonce: nil,
wantErr: true,
errContains: "Azure attestation token is not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := provider{}
got, err := p.AzureAttestationToken(tt.tokenNonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
assert.Nil(t, got)
})
}
}
func TestNewVerifier(t *testing.T) {
tests := []struct {
name string
}{
{
name: "should create new verifier successfully",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewVerifier()
v, ok := got.(verifier)
assert.True(t, ok)
assert.NotNil(t, v.Policy)
assert.NotNil(t, v.Policy.RootOfTrust)
assert.NotNil(t, v.Policy.Policy)
assert.NotNil(t, v.Policy.Policy.HeaderPolicy)
assert.NotNil(t, v.Policy.Policy.TdQuoteBodyPolicy)
})
}
}
func TestNewVerifierWithPolicy(t *testing.T) {
tests := []struct {
name string
policy *checkconfig.Config
}{
{
name: "should create verifier with nil policy",
policy: nil,
},
{
name: "should create verifier with valid policy",
policy: &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{},
},
},
{
name: "should create verifier with empty policy",
policy: &checkconfig.Config{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewVerifierWithPolicy(tt.policy)
v, ok := got.(verifier)
assert.True(t, ok)
if tt.policy == nil {
assert.NotNil(t, v.Policy)
assert.NotNil(t, v.Policy.RootOfTrust)
assert.NotNil(t, v.Policy.Policy)
} else {
assert.Equal(t, tt.policy, v.Policy)
}
})
}
}
func TestVerifier_VerifTeeAttestation(t *testing.T) {
tests := []struct {
name string
verifier verifier
report []byte
teeNonce []byte
wantErr bool
errContains string
}{
{
name: "should return error when policy is nil",
verifier: verifier{
Policy: nil,
},
report: []byte("test-report"),
teeNonce: []byte("test-nonce"),
wantErr: true,
errContains: "tdx policy is not provided",
},
{
name: "should handle invalid report format",
verifier: verifier{
Policy: &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{},
},
},
report: []byte("invalid-report"),
teeNonce: []byte("test-nonce"),
wantErr: true,
errContains: "",
},
{
name: "should handle empty report",
verifier: verifier{
Policy: &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{},
},
},
report: []byte{},
teeNonce: []byte("test-nonce"),
wantErr: true,
errContains: "",
},
{
name: "should handle nil report",
verifier: verifier{
Policy: &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{},
},
},
report: nil,
teeNonce: []byte("test-nonce"),
wantErr: true,
errContains: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.verifier.VerifTeeAttestation(tt.report, tt.teeNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestVerifier_VerifVTpmAttestation(t *testing.T) {
tests := []struct {
name string
verifier verifier
report []byte
vTpmNonce []byte
wantErr bool
errContains string
}{
{
name: "should return error for any input",
verifier: verifier{},
report: []byte("test-report"),
vTpmNonce: []byte("test-nonce"),
wantErr: true,
errContains: "VTPM attestation verification is not supported",
},
{
name: "should return error for empty inputs",
verifier: verifier{},
report: []byte{},
vTpmNonce: []byte{},
wantErr: true,
errContains: "VTPM attestation verification is not supported",
},
{
name: "should return error for nil inputs",
verifier: verifier{},
report: nil,
vTpmNonce: nil,
wantErr: true,
errContains: "VTPM attestation verification is not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.verifier.VerifVTpmAttestation(tt.report, tt.vTpmNonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
})
}
}
func TestVerifier_VerifyAttestation(t *testing.T) {
tests := []struct {
name string
verifier verifier
report []byte
teeNonce []byte
vTpmNonce []byte
wantErr bool
errContains string
}{
{
name: "should delegate to VerifTeeAttestation with nil policy",
verifier: verifier{
Policy: nil,
},
report: []byte("test-report"),
teeNonce: []byte("test-nonce"),
vTpmNonce: []byte("vtpm-nonce"),
wantErr: true,
errContains: "tdx policy is not provided",
},
{
name: "should delegate to VerifTeeAttestation with valid policy",
verifier: verifier{
Policy: &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{},
},
},
report: []byte("invalid-report"),
teeNonce: []byte("test-nonce"),
vTpmNonce: []byte("vtpm-nonce"),
wantErr: true,
errContains: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.verifier.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestVerifier_JSONToPolicy(t *testing.T) {
tempDir := t.TempDir()
testPolicy := &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{
HeaderPolicy: &checkconfig.HeaderPolicy{},
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
},
}
validPolicyJSON, err := protojson.Marshal(testPolicy)
require.NoError(t, err)
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
require.NoError(t, err)
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
require.NoError(t, err)
tests := []struct {
name string
verifier verifier
path string
wantErr bool
errContains string
}{
{
name: "should load valid policy file",
verifier: verifier{
Policy: &checkconfig.Config{},
},
path: validPolicyFile,
wantErr: false,
},
{
name: "should return error for non-existent file",
verifier: verifier{
Policy: &checkconfig.Config{},
},
path: filepath.Join(tempDir, "non_existent.json"),
wantErr: true,
errContains: "no such file or directory",
},
{
name: "should return error for invalid JSON",
verifier: verifier{
Policy: &checkconfig.Config{},
},
path: invalidPolicyFile,
wantErr: true,
errContains: "",
},
{
name: "should return error for empty path",
verifier: verifier{
Policy: &checkconfig.Config{},
},
path: "",
wantErr: true,
errContains: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.verifier.JSONToPolicy(tt.path)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestReadTDXAttestationPolicy(t *testing.T) {
tempDir := t.TempDir()
testPolicy := &checkconfig.Config{
RootOfTrust: &checkconfig.RootOfTrust{},
Policy: &checkconfig.Policy{
HeaderPolicy: &checkconfig.HeaderPolicy{},
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
},
}
validPolicyJSON, err := protojson.Marshal(testPolicy)
require.NoError(t, err)
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
require.NoError(t, err)
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
require.NoError(t, err)
emptyFile := filepath.Join(tempDir, "empty.json")
err = os.WriteFile(emptyFile, []byte{}, 0o644)
require.NoError(t, err)
tests := []struct {
name string
policyPath string
policy *checkconfig.Config
wantErr bool
errContains string
}{
{
name: "should read valid policy file",
policyPath: validPolicyFile,
policy: &checkconfig.Config{},
wantErr: false,
},
{
name: "should return error for non-existent file",
policyPath: filepath.Join(tempDir, "non_existent.json"),
policy: &checkconfig.Config{},
wantErr: true,
errContains: "no such file or directory",
},
{
name: "should return error for invalid JSON",
policyPath: invalidPolicyFile,
policy: &checkconfig.Config{},
wantErr: true,
errContains: "",
},
{
name: "should return error for empty file",
policyPath: emptyFile,
policy: &checkconfig.Config{},
wantErr: true,
errContains: "",
},
{
name: "should return error for empty path",
policyPath: "",
policy: &checkconfig.Config{},
wantErr: true,
errContains: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadTDXAttestationPolicy(tt.policyPath, tt.policy)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
} else {
assert.NoError(t, err)
assert.NotNil(t, tt.policy)
}
})
}
}
+634 -1
View File
@@ -4,7 +4,11 @@
package vtpm
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"testing"
@@ -13,6 +17,8 @@ import (
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-tpm-tools/proto/attest"
ptpm "github.com/google/go-tpm-tools/proto/tpm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
@@ -24,6 +30,633 @@ const sevSnpProductMilan = "Milan"
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
type mockTPM struct {
*bytes.Buffer
closeErr error
}
func (m *mockTPM) Close() error {
return m.closeErr
}
type mockWriter struct {
data []byte
err error
}
func (m *mockWriter) Write(p []byte) (n int, err error) {
if m.err != nil {
return 0, m.err
}
m.data = append(m.data, p...)
return len(p), nil
}
func TestOpenTpm(t *testing.T) {
tests := []struct {
name string
externalTPM io.ReadWriteCloser
expectError bool
}{
{
name: "External TPM available",
externalTPM: &mockTPM{Buffer: &bytes.Buffer{}},
expectError: false,
},
{
name: "No external TPM",
externalTPM: nil,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalExternalTPM := ExternalTPM
defer func() { ExternalTPM = originalExternalTPM }()
ExternalTPM = tt.externalTPM
tpm, err := OpenTpm()
if tt.expectError {
assert.Error(t, err)
} else {
if tt.externalTPM != nil {
assert.NoError(t, err)
assert.NotNil(t, tpm)
}
}
})
}
}
func TestTpmEventLog(t *testing.T) {
tempFile, err := os.CreateTemp("", "event_log")
require.NoError(t, err)
defer os.Remove(tempFile.Name())
testData := []byte("test event log data")
_, err = tempFile.Write(testData)
require.NoError(t, err)
tempFile.Close()
tpm := &tpm{ReadWriteCloser: &mockTPM{Buffer: &bytes.Buffer{}}}
_, err = tpm.EventLog()
assert.Error(t, err)
}
func TestNewProvider(t *testing.T) {
tests := []struct {
name string
teeAttestation bool
vmpl uint
}{
{
name: "TEE attestation enabled",
teeAttestation: true,
vmpl: 1,
},
{
name: "TEE attestation disabled",
teeAttestation: false,
vmpl: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := NewProvider(tt.teeAttestation, tt.vmpl)
assert.NotNil(t, provider)
})
}
}
func TestProviderAzureAttestationToken(t *testing.T) {
provider := NewProvider(false, 0)
token, err := provider.AzureAttestationToken([]byte("test-nonce"))
assert.Error(t, err)
assert.Nil(t, token)
assert.Contains(t, err.Error(), "Azure attestation token is not supported")
}
func TestNewVerifier(t *testing.T) {
writer := &mockWriter{}
verifier := NewVerifier(writer)
assert.NotNil(t, verifier)
}
func TestNewVerifierWithPolicy(t *testing.T) {
writer := &mockWriter{}
policy := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
tests := []struct {
name string
policy *attestation.Config
}{
{
name: "With policy",
policy: policy,
},
{
name: "Without policy (nil)",
policy: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verifier := NewVerifierWithPolicy([]byte("test-key"), writer, tt.policy)
assert.NotNil(t, verifier)
})
}
}
func TestMarshalQuote(t *testing.T) {
tests := []struct {
name string
attestation *attest.Attestation
expectError bool
}{
{
name: "Valid attestation",
attestation: &attest.Attestation{
AkPub: []byte("test-key"),
},
expectError: false,
},
{
name: "Nil attestation",
attestation: nil,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := marshalQuote(tt.attestation)
if tt.expectError {
assert.Error(t, err)
assert.Empty(t, data)
} else {
assert.NoError(t, err)
if tt.attestation != nil {
assert.NotEmpty(t, data)
}
}
})
}
}
func TestCheckExpectedPCRValues(t *testing.T) {
testPCRValue := make([]byte, 32)
for i := range testPCRValue {
testPCRValue[i] = byte(i)
}
tests := []struct {
name string
attestation *attest.Attestation
policy *attestation.Config
expectError bool
errorMsg string
}{
{
name: "Matching PCR values SHA256",
attestation: &attest.Attestation{
Quotes: []*ptpm.Quote{
{
Pcrs: &ptpm.PCRs{
Hash: ptpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
0: testPCRValue,
},
},
},
},
},
policy: &attestation.Config{
PcrConfig: &attestation.PcrConfig{
PCRValues: attestation.PcrValues{
Sha256: map[string]string{
"0": hex.EncodeToString(testPCRValue),
},
},
},
},
expectError: false,
},
{
name: "Mismatched PCR values",
attestation: &attest.Attestation{
Quotes: []*ptpm.Quote{
{
Pcrs: &ptpm.PCRs{
Hash: ptpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
0: testPCRValue,
},
},
},
},
},
policy: &attestation.Config{
PcrConfig: &attestation.PcrConfig{
PCRValues: attestation.PcrValues{
Sha256: map[string]string{
"0": hex.EncodeToString(make([]byte, 32)),
},
},
},
},
expectError: true,
errorMsg: "expected",
},
{
name: "Unsupported hash algorithm",
attestation: &attest.Attestation{
Quotes: []*ptpm.Quote{
{
Pcrs: &ptpm.PCRs{
Hash: ptpm.HashAlgo_HASH_INVALID,
Pcrs: map[uint32][]byte{
0: testPCRValue,
},
},
},
},
},
policy: &attestation.Config{
PcrConfig: &attestation.PcrConfig{},
},
expectError: true,
errorMsg: "hash algo is not supported",
},
{
name: "Invalid PCR index",
attestation: &attest.Attestation{
Quotes: []*ptpm.Quote{
{
Pcrs: &ptpm.PCRs{
Hash: ptpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
0: testPCRValue,
},
},
},
},
},
policy: &attestation.Config{
PcrConfig: &attestation.PcrConfig{
PCRValues: attestation.PcrValues{
Sha256: map[string]string{
"invalid": hex.EncodeToString(testPCRValue),
},
},
},
},
expectError: true,
errorMsg: "error converting PCR index to int32",
},
{
name: "Invalid PCR value hex",
attestation: &attest.Attestation{
Quotes: []*ptpm.Quote{
{
Pcrs: &ptpm.PCRs{
Hash: ptpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
0: testPCRValue,
},
},
},
},
},
policy: &attestation.Config{
PcrConfig: &attestation.PcrConfig{
PCRValues: attestation.PcrValues{
Sha256: map[string]string{
"0": "invalid-hex",
},
},
},
},
expectError: true,
errorMsg: "error converting PCR value to byte",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkExpectedPCRValues(tt.attestation, tt.policy)
if tt.expectError {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestReadPolicy(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
validPolicy := map[string]interface{}{
"policy": map[string]interface{}{
"product": map[string]interface{}{
"name": "test-product",
},
},
"rootOfTrust": map[string]interface{}{
"productLine": "test-line",
},
"pcrConfig": map[string]interface{}{
"pcrValues": map[string]interface{}{
"sha256": map[string]string{
"0": "0000000000000000000000000000000000000000000000000000000000000000",
},
},
},
}
validPolicyData, err := json.Marshal(validPolicy)
require.NoError(t, err)
validPolicyPath := filepath.Join(tempDir, "valid_policy.json")
err = os.WriteFile(validPolicyPath, validPolicyData, 0o644)
require.NoError(t, err)
tests := []struct {
name string
policyPath string
expectError bool
expectedErr error
}{
{
name: "Valid policy file",
policyPath: validPolicyPath,
expectError: false,
},
{
name: "Non-existent policy file",
policyPath: "/nonexistent/path",
expectError: true,
expectedErr: ErrAttestationPolicyOpen,
},
{
name: "Empty policy path",
policyPath: "",
expectError: true,
expectedErr: ErrAttestationPolicyMissing,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
err := ReadPolicy(tt.policyPath, config)
if tt.expectError {
assert.Error(t, err)
if tt.expectedErr != nil {
assert.True(t, errors.Contains(err, tt.expectedErr))
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestReadPolicyFromByte(t *testing.T) {
tests := []struct {
name string
policyData []byte
expectError bool
expectedErr error
}{
{
name: "Valid policy data",
policyData: []byte(`{
"policy": {
"product": {
"name": "test-product"
}
},
"rootOfTrust": {
"productLine": "test-line"
},
"pcrConfig": {
"pcrValues": {
"sha256": {
"0": "0000000000000000000000000000000000000000000000000000000000000000"
}
}
}
}`),
expectError: false,
},
{
name: "Invalid JSON",
policyData: []byte(`{invalid json`),
expectError: true,
expectedErr: ErrAttestationPolicyDecode,
},
{
name: "Empty policy data",
policyData: []byte(``),
expectError: true,
expectedErr: ErrAttestationPolicyDecode,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
err := ReadPolicyFromByte(tt.policyData, config)
if tt.expectError {
assert.Error(t, err)
if tt.expectedErr != nil {
assert.True(t, errors.Contains(err, tt.expectedErr))
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestConvertPolicyToJSON(t *testing.T) {
tests := []struct {
name string
config *attestation.Config
expectError bool
expectedErr error
}{
{
name: "Valid config",
config: &attestation.Config{
Config: &check.Config{
Policy: &check.Policy{
Product: &sevsnp.SevProduct{
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
},
RootOfTrust: &check.RootOfTrust{
ProductLine: "Milan",
},
},
PcrConfig: &attestation.PcrConfig{
PCRValues: attestation.PcrValues{
Sha256: map[string]string{
"0": "0000000000000000000000000000000000000000000000000000000000000000",
},
},
},
},
expectError: false,
},
{
name: "Nil config",
config: &attestation.Config{
Config: nil,
PcrConfig: &attestation.PcrConfig{},
},
expectError: false,
expectedErr: ErrProtoMarshalFailed,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jsonData, err := ConvertPolicyToJSON(tt.config)
if tt.expectError {
assert.Error(t, err)
if tt.expectedErr != nil {
assert.True(t, errors.Contains(err, tt.expectedErr))
}
assert.Nil(t, jsonData)
} else {
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var result map[string]interface{}
err = json.Unmarshal(jsonData, &result)
assert.NoError(t, err)
}
})
}
}
func TestVTPMVerify(t *testing.T) {
writer := &mockWriter{}
policy := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
tests := []struct {
name string
quote []byte
teeNonce []byte
vtpmNonce []byte
expectError bool
}{
{
name: "Invalid quote data",
quote: []byte("invalid"),
teeNonce: []byte("tee-nonce"),
vtpmNonce: []byte("vtpm-nonce"),
expectError: true,
},
{
name: "Empty quote",
quote: []byte{},
teeNonce: []byte("tee-nonce"),
vtpmNonce: []byte("vtpm-nonce"),
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VTPMVerify(tt.quote, tt.teeNonce, tt.vtpmNonce, writer, policy)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestVerifyQuote(t *testing.T) {
writer := &mockWriter{}
policy := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
tests := []struct {
name string
quote []byte
vtpmNonce []byte
expectError bool
}{
{
name: "Invalid quote data",
quote: []byte("invalid"),
vtpmNonce: []byte("vtpm-nonce"),
expectError: true,
},
{
name: "Empty quote",
quote: []byte{},
vtpmNonce: []byte("vtpm-nonce"),
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyQuote(tt.quote, tt.vtpmNonce, writer, policy)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestWriterError(t *testing.T) {
writer := &mockWriter{err: fmt.Errorf("write error")}
policy := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
err := VerifyQuote([]byte("invalid"), []byte("nonce"), writer, policy)
assert.Error(t, err)
}
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
@@ -33,7 +666,7 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
// Change random data so in the signature so the signature failes
// Change random data so in the signature so the signature fails
attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01
tests := []struct {
+128
View File
@@ -0,0 +1,128 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cvm
import (
"fmt"
"net"
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/mocks"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
type TestServer struct {
agent.UnimplementedAgentServiceServer
server *grpc.Server
health *health.Server
port int
listenAddr string
}
func NewTestServer() (*TestServer, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to listen: %v", err)
}
addr := listener.Addr().(*net.TCPAddr)
server := grpc.NewServer()
healthServer := health.NewServer()
ts := &TestServer{
server: server,
health: healthServer,
port: addr.Port,
listenAddr: fmt.Sprintf("localhost:%d", addr.Port),
}
svc := new(mocks.Service)
agent.RegisterAgentServiceServer(server, agentgrpc.NewServer(svc))
grpchealth.RegisterHealthServer(server, healthServer)
go func() {
if err := server.Serve(listener); err != nil {
fmt.Printf("Server exited with error: %v\n", err)
}
}()
healthServer.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
return ts, nil
}
func (s *TestServer) Stop() {
if s.server != nil {
s.server.GracefulStop()
}
}
func TestAgentClientIntegration(t *testing.T) {
testServer, err := NewTestServer()
require.NoError(t, err)
defer testServer.Stop()
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
serverRunning bool
config pkggrpc.CVMClientConfig
err error
}{
{
name: "successful connection",
serverRunning: true,
config: pkggrpc.CVMClientConfig{
BaseConfig: pkggrpc.BaseConfig{
URL: testServer.listenAddr,
Timeout: 1,
},
},
err: nil,
},
{
name: "server not healthy",
serverRunning: false,
config: pkggrpc.CVMClientConfig{
BaseConfig: pkggrpc.BaseConfig{
URL: "",
Timeout: 1,
},
},
err: errors.New("failed to connect to grpc server"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !tt.serverRunning {
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_NOT_SERVING)
} else {
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
}
client, agentClient, err := NewCVMClient(tt.config)
assert.True(t, errors.Contains(err, tt.err))
if err != nil {
assert.Nil(t, client)
assert.Nil(t, agentClient)
return
}
require.NotNil(t, client)
require.NotNil(t, agentClient)
defer client.Close()
})
}
}
+70
View File
@@ -529,6 +529,76 @@ func TestReceiveAttestation(t *testing.T) {
}
}
func TestReceiverIMAMeasurements(t *testing.T) {
tests := []struct {
name string
description string
totalSize int
chunks [][]byte
wantResult []byte
wantErr error
}{
{
name: "successful single chunk receive",
description: "Receiving IMA measurements",
totalSize: 20,
chunks: [][]byte{[]byte("12345678912345678999")},
wantResult: []byte("12345678912345678999"),
wantErr: nil,
},
{
name: "stream error",
description: "Receiving IMA measurements",
totalSize: 20,
chunks: [][]byte{[]byte("12345678912345678999")},
wantResult: nil,
wantErr: errors.New("stream error"),
},
{
name: "size mismatch",
description: "Receiving IMA measurements",
totalSize: 10,
chunks: [][]byte{[]byte("12345678912345678999")},
wantResult: nil,
wantErr: errors.New("progress update exceeds total bytes: attempted to add 20 bytes, but only 10 bytes remain"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStream := new(mocks.AgentService_IMAMeasurementsClient[agent.IMAMeasurementsResponse])
p := New(true)
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
resultFile, err := os.CreateTemp("", "test_ima_measurements")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(resultFile.Name())
})
if tt.wantErr != nil {
mockStream.On("Recv").Return(nil, tt.wantErr).Once()
}
mockStream.On("Recv").Return(&agent.IMAMeasurementsResponse{Pcr10: []byte(tt.chunks[0]), File: []byte(tt.chunks[0])}, nil).Once()
mockStream.On("Recv").Return(nil, io.EOF).Once()
pcr10, err := p.ReceiveIMAMeasurements(tt.description, tt.totalSize, mockStream, resultFile)
assert.NoError(t, resultFile.Close())
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, pcr10)
}
})
}
}
type mockAlgoStream struct {
stream agent.AgentService_AlgoClient
sendCount int
+29 -29
View File
@@ -176,6 +176,35 @@ func (sdk *agentSDK) AttestationResult(ctx context.Context, nonce [size32]byte,
return nil
}
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
request := &agent.IMAMeasurementsRequest{}
stream, err := sdk.client.IMAMeasurements(ctx, request)
if err != nil {
return nil, err
}
incomingmd, err := stream.Header()
if err != nil {
return nil, err
}
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
if len(fileSizeStr) == 0 {
fileSizeStr = append(fileSizeStr, "0")
}
fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil {
return nil, err
}
pb := progressbar.New(true)
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
}
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
var signature []byte
var err error
@@ -208,32 +237,3 @@ func generateMetadata(userID string, privateKey crypto.PrivateKey) (metadata.MD,
kv[auth.SignatureMetadataKey] = base64.StdEncoding.EncodeToString(signature)
return metadata.New(kv), nil
}
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
request := &agent.IMAMeasurementsRequest{}
stream, err := sdk.client.IMAMeasurements(ctx, request)
if err != nil {
return nil, err
}
incomingmd, err := stream.Header()
if err != nil {
return nil, err
}
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
if len(fileSizeStr) == 0 {
fileSizeStr = append(fileSizeStr, "0")
}
fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil {
return nil, err
}
pb := progressbar.New(true)
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
}
+76
View File
@@ -558,6 +558,82 @@ func TestAttestationResult(t *testing.T) {
}
}
func TestIMAMeasurements(t *testing.T) {
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}
defer conn.Close()
client := agent.NewAgentServiceClient(conn)
sdk := sdk.NewAgentSDK(client)
response := &agent.IMAMeasurementsResponse{
File: []byte{
0x01, 0x02, 0x03, 0x04,
0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04,
0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04,
0x05, 0x06, 0x07, 0x08,
},
}
cases := []struct {
name string
response *agent.IMAMeasurementsResponse
svcRes []byte
err error
}{
{
name: "fetch IMA measurements successfully",
response: response,
svcRes: response.File,
err: nil,
},
{
name: "failed to fetch IMA measurements",
response: &agent.IMAMeasurementsResponse{File: []byte{}},
svcRes: nil,
err: errors.New("failed to fetch IMA measurements"),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
svcCall := svc.On("IMAMeasurements", mock.Anything).Return(tc.svcRes, tc.svcRes, tc.err)
file, err := os.CreateTemp("", "ima_measurements")
require.NoError(t, err)
t.Cleanup(func() {
os.Remove(file.Name())
})
_, err = sdk.IMAMeasurements(context.Background(), file)
require.NoError(t, file.Close())
st, ok := status.FromError(err)
if !ok {
t.Fatalf("Expected gRPC status error, but got: %v", err)
}
if tc.err != nil {
if st.Message() != tc.err.Error() {
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
}
}
res, err := os.ReadFile(file.Name())
require.NoError(t, err)
assert.Equal(t, tc.response.File, res, tc.name)
svcCall.Unset()
})
}
}
func generateKeys(t *testing.T, keyType string) (priv any, pub []byte) {
switch keyType {
case "ecdsa":