cocos/pkg/atls/atls_test.go
Sammy Kerata Oina da31d76c94
NOISSUE - Agent Pull mode for remote resources (#575)
* feat(kbs): implement KBS client for attestation and resource retrieval

- Added KBS client implementation in pkg/kbs/client.go with methods for attestation and resource retrieval.
- Introduced necessary data structures for requests and responses.
- Implemented error handling for various scenarios.

test(kbs): add unit tests for KBS client

- Created comprehensive tests for the KBS client in pkg/kbs/client_test.go.
- Included tests for attestation success and failure cases, as well as resource retrieval.

feat(registry): introduce HTTP and S3 registry implementations

- Added HTTPRegistry for downloading resources over HTTP/HTTPS with retry logic in pkg/registry/http.go.
- Implemented S3Registry for downloading resources from AWS S3 and S3-compatible services in pkg/registry/s3.go.
- Included error handling and configuration options for both registries.

chore(registry): define registry interface and configuration

- Created registry interface and configuration struct in pkg/registry/registry.go.
- Added default configuration settings for registry clients.

docs(cvms): update README for CVMS server configuration and usage

- Enhanced documentation for CVMS server with detailed command-line flags and usage examples.
- Clarified direct upload and remote resource modes, including KBS integration.

fix(cvms): integrate KBS for remote resource handling in main.go

- Updated main.go to support remote datasets and algorithms using KBS.
- Added validation for command-line flags to ensure proper configuration.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Move ifeq conditional outside define block in attestation-service.mk

Make conditionals cannot be evaluated inside define...endef blocks
when used as recipe bodies. Restructured to define the
ATTESTATION_SERVICE_INSTALL_INIT_SYSTEMD block conditionally based
on BR2_PACKAGE_CC_ATTESTATION_AGENT configuration.

* feat: Implement remote resource downloading for algorithms and datasets using AWS S3/MinIO credentials.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add comprehensive documentation and agent support for testing remote resource download with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Improve agent logging for remote resource configuration and KBS status, and add a testing guide for remote resource downloads with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add a comprehensive guide for testing remote resource download with KBS attestation and update multiple package versions to a specific commit.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add failure transitions for resource reception states and a comprehensive guide for testing remote resource downloads with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Implement remote resource download with KBS attestation in the agent and add a comprehensive testing guide.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Add comprehensive guide for testing remote resource download with KBS attestation and include a debug log in the attestation client.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Delegate KBS attestation and token retrieval to a new attestation-agent service and document remote resource testing.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* client fixes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* raw evidence

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Build all Go files in cmd directories, not just main.go

This fixes the issue where fetch_raw_evidence.go wasn't being included
in the attestation-service build.

* fix: Wrap binary evidence in JSON for KBS compatibility

Fixes 'invalid character' error by wrapping raw binary evidence
in a JSON structure with base64 encoding, as expected by KBS.

* chore: Update buildroot packages to c28cefae

Includes fixes for:
1. attestation-service build (including fetch_raw_evidence.go)
2. Agent KBS evidence format (wrapping binary in JSON)

* fix: Implement KBS RCAR handshake with cookies

Fixes 'cookie not found' error (401) from KBS by:
1. Adding CookieJar support to KBS client
2. Implementing GetChallenge() to perform /auth handshake and capture session cookie
3. Updating Agent to get challenge, decode nonce, and use it for evidence generation
4. Regenerating mocks
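The cookie handling in this fix can be sketched with the standard library alone. This is a minimal illustration, not the `pkg/kbs` client itself. The names `newKBSHTTPClient`, `getChallenge` and the `challenge` type are hypothetical, and the `nonce`/`extra-params` field names are assumed from the KBS attestation protocol. The key point is that an `http.Client` with a `cookiejar` stores the session cookie returned by `/auth` and replays it on the following `/attest` request, which is what avoids the 401 'cookie not found' error.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/cookiejar"
)

// challenge mirrors the KBS Challenge message (field names assumed from the
// KBS attestation protocol; verify against the spec).
type challenge struct {
	Nonce       string         `json:"nonce"`
	ExtraParams map[string]any `json:"extra-params"`
}

// newKBSHTTPClient returns a client whose cookie jar keeps the session cookie
// set during /auth, so later /attest and resource requests send it back.
func newKBSHTTPClient() (*http.Client, error) {
	jar, err := cookiejar.New(nil)
	if err != nil {
		return nil, err
	}
	return &http.Client{Jar: jar}, nil
}

// getChallenge posts the RCAR Request to authURL and decodes the Challenge.
// Any Set-Cookie header in the response is stored in client.Jar.
func getChallenge(client *http.Client, authURL string, request any) (*challenge, error) {
	body, err := json.Marshal(request)
	if err != nil {
		return nil, err
	}
	resp, err := client.Post(authURL, "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("kbs auth: unexpected status %d", resp.StatusCode)
	}
	var ch challenge
	if err := json.NewDecoder(resp.Body).Decode(&ch); err != nil {
		return nil, err
	}
	return &ch, nil
}
```

A caller would post the Request to the KBS `/auth` endpoint, use `Nonce` when producing evidence, and then reuse the same client for the attestation step so the cookie travels with it.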

* chore: Update buildroot packages to f6981ac5

Includes KBS RCAR handshake fix (cookie support + GetChallenge loop)

* fix: Update KBS client JSON tags to kebab-case

Fixes deserialization error (401) from KBS by:
1. Using kebab-case (e.g. extra-params) for JSON tags as per protocol.
2. Initializing ExtraParams as empty object {} instead of null/omitted.
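Both JSON fixes can be illustrated with `encoding/json`. This sketch assumes a Request carrying `version`, `tee` and `extra-params`; the `kbsRequest` type and its helpers are hypothetical names. It shows that a nil map marshals to `null`, while an initialized empty map marshals to `{}`.

```go
package main

import "encoding/json"

// kbsRequest is a sketch of the RCAR Request with kebab-case JSON tags.
type kbsRequest struct {
	Version     string         `json:"version"`
	Tee         string         `json:"tee"`
	ExtraParams map[string]any `json:"extra-params"`
}

// newKBSRequest always initializes ExtraParams so it serializes as {}.
func newKBSRequest(version, tee string) kbsRequest {
	return kbsRequest{Version: version, Tee: tee, ExtraParams: map[string]any{}}
}

// marshalKBSRequest returns the JSON body that would be sent to the KBS.
func marshalKBSRequest(r kbsRequest) (string, error) {
	b, err := json.Marshal(r)
	return string(b), err
}
```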

* fix: Wrap attestation evidence in primary_evidence format

Updates Agent to construct 'tee-evidence' payload with:
- primary_evidence: containing the actual quote/data
- additional_evidence: empty JSON object

This matches the Confidential Containers KBS Attestation Protocol requirements.

* fix: Update KBS protocol version to 0.4.0

KBS rejected 0.1.0 with a version mismatch error. Bumping to 0.4.0 to match server expectation.

* fix: Generate ephemeral key for KBS RuntimeData

Updates RuntimeData to include a valid ephemeral EC P-256 public key in JWK format, as required by the KBS RCAR protocol.
Also fixes the KBS client struct to support TEEPubKey as an object.
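The following sketch generates the ephemeral P-256 key and encodes its public half as a JWK, using `crypto/ecdh` from the standard library. The `ecJWK` type and `newEphemeralP256` helper are hypothetical. The JWK carries only `kty`, `crv`, `x` and `y` (RFC 7518), and whether the KBS also requires fields such as `alg` is an assumption to check against the protocol spec. Coordinates are the fixed-width 32-byte big-endian values, encoded as unpadded base64url.

```go
package main

import (
	"crypto/ecdh"
	"crypto/rand"
	"encoding/base64"
)

// ecJWK is a minimal EC public key in JWK form (RFC 7517/7518).
type ecJWK struct {
	Kty string `json:"kty"`
	Crv string `json:"crv"`
	X   string `json:"x"`
	Y   string `json:"y"`
}

// newEphemeralP256 generates a fresh P-256 key pair and returns the private
// key together with its public half encoded as a JWK.
func newEphemeralP256() (*ecdh.PrivateKey, ecJWK, error) {
	priv, err := ecdh.P256().GenerateKey(rand.Reader)
	if err != nil {
		return nil, ecJWK{}, err
	}
	// Uncompressed point encoding: 0x04 || X (32 bytes) || Y (32 bytes).
	pt := priv.PublicKey().Bytes()
	enc := base64.RawURLEncoding // JWK uses base64url without padding
	return priv, ecJWK{
		Kty: "EC",
		Crv: "P-256",
		X:   enc.EncodeToString(pt[1:33]),
		Y:   enc.EncodeToString(pt[33:65]),
	}, nil
}
```

The private key stays in the agent and would later be used to decrypt the response the KBS wraps for this public key.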

* fix: Update sample attestation quote to valid JSON

The default attestation.bin was binary, but the KBS Sample Verifier expects a valid JSON quote containing 'svn' and 'report_data'.
Updated the embedded bin file to contain this JSON structure.

* fix: Generate dynamic JSON quote for Sample TEE in FetchRawEvidence

The KBS Sample Verifier expects a JSON object with 'svn' and 'report_data'.
Previously, we were returning raw binary data (reportData+nonce).
This commit updates FetchRawEvidence to return a marshaled JSON structure with:
- svn: "1"
- report_data: base64(req.ReportData)
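The sample quote shape described above can be produced with a small helper. This sketch assumes standard (padded) base64 for `report_data`, since the commit only says "base64". `sampleQuoteJSON` is a hypothetical name, not the attestation-service function.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
)

// sampleQuote is the JSON structure the KBS Sample Verifier reads.
type sampleQuote struct {
	Svn        string `json:"svn"`
	ReportData string `json:"report_data"`
}

// sampleQuoteJSON marshals a sample quote with a fixed svn of "1" and the
// base64-encoded report data.
func sampleQuoteJSON(reportData []byte) ([]byte, error) {
	return json.Marshal(sampleQuote{
		Svn:        "1",
		ReportData: base64.StdEncoding.EncodeToString(reportData),
	})
}
```

For example, `sampleQuoteJSON([]byte{1, 2, 3})` yields `{"svn":"1","report_data":"AQID"}`.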

* refactor: Delegate Sample Attestation to Provider

Refactored sample attestation logic:
- Moved JSON Quote generation into EmptyProvider (standalone mode).
- Updated FetchRawEvidence to call provider.TeeAttestation instead of manual generation.
This enables using the real CC Attestation Agent for UNSPECIFIED platform if configured.

* feat: Add comprehensive debug logging and enforce CC AA usage

Changes:
- Updated EmptyProvider to return error instead of generating mock data
  This forces proper use of CC Attestation Agent's sample attester
- Added detailed logging to attestation-service FetchRawEvidence:
  * Hex dump of evidence (first 200 bytes)
  * String preview of evidence
  * Total evidence length
- Added detailed logging to agent service:
  * Raw evidence hex and string previews
  * KBS evidence JSON preview (first 500 bytes)
  * Evidence lengths at each transformation step

This logging will help diagnose why KBS Sample Verifier is rejecting evidence.

* fix: Enable CC AA by default and add attestation-service log forwarding

Changes:
- Set USE_CC_ATTESTATION_AGENT=true by default in systemd service
- Added StandardOutput/StandardError to forward logs to /var/log/cocos/
- Updated HAL makefile to handle new default value
- This ensures attestation-service uses CC AA's sample attester
- Logs will now be visible in CVMS output for debugging

* feat: Add gRPC log forwarding to attestation-service

Implemented the same log forwarding mechanism used by the agent:
- Added ProtoHandler to write logs to both stdout and logQueue
- Connected to log client (/run/cocos/log.sock) for gRPC forwarding
- Added goroutine to forward logs to CVMS via log client
- Logs will now appear in CVMS output during computation runs

This enables visibility into attestation-service debug output including:
- CC AA connection status
- Evidence generation details (hex dumps, string previews)
- Any errors from providers

* fix: Parse sample evidence JSON instead of base64-encoding it

The attestation-service returns sample evidence as JSON:
{"svn":"1","report_data":"base64..."}

The agent was incorrectly base64-encoding this JSON string again.
KBS Sample Verifier expects the parsed JSON object directly.

Fixed by:
- Parsing the JSON evidence from attestation-service
- Passing the parsed object directly in primary_evidence.evidence
- This matches what KBS Sample Verifier expects

* debug: Increase KBS evidence logging preview to 1000 bytes

Show the complete JSON structure being sent to KBS to debug
the attestation failure.

* debug: Add comprehensive CC AA configuration logging

Added debug logs to show:
- Whether CC AA is enabled in config
- CC AA address being used
- Connection success/failure
- Which provider is ultimately selected
- Warning when falling back to EmptyProvider

This will help diagnose why EmptyProvider is being used
instead of CC Attestation Agent.

* debug: Add startup logging for log client connection

Added log message to show if log client connection succeeds
at attestation-service startup. This will help diagnose why
logs aren't appearing in CVMS output.

* feat: Add retry logic with exponential backoff to log client

Added simple retry mechanism to handle concurrent log requests:
- 3 retry attempts with exponential backoff (10ms, 20ms, 40ms)
- Applies to both SendLog and SendEvent methods
- Centralized in log client so all services benefit
- Should eliminate 'failed to send log' errors from concurrent requests

This fixes the issue where attestation-service logs weren't
appearing in CVMS output due to dropped messages.

* fix: Flatten sample evidence fields in primary_evidence for KBS

KBS Sample Verifier expects svn and report_data at the top level
of primary_evidence, not nested under an 'evidence' key.

Changed structure from:
{"primary_evidence": {"tee": "sample", "evidence": {"svn": "1", ...}}}

To:
{"primary_evidence": {"tee": "sample", "svn": "1", "report_data": "...", ...}}

This matches what KBS expects when deserializing the Quote structure.

* fix: Use sample quote directly as primary_evidence per KBS protocol

According to KBS attestation protocol spec, for sample TEE type,
primary_evidence should be the sample quote JSON directly:
{"svn": "1", "report_data": "..."}

Removed extra 'tee' and 'platform' fields that were causing KBS
to fail deserializing the Quote structure. The 'tee' field is
already sent in the Request payload during RCAR handshake.

Refs:
- https://github.com/confidential-containers/trustee/blob/main/kbs/docs/kbs_attestation_protocol.md
- https://github.com/confidential-containers/guest-components/blob/main/attestation-agent/attester/src/sample/mod.rs
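When the sample quote is used directly as `primary_evidence`, `json.RawMessage` embeds the attester's JSON verbatim instead of base64-encoding it a second time. This sketch assumes the `tee-evidence` shape from the earlier primary_evidence commit, with `additional_evidence` left as an empty object. The type and function names are hypothetical.

```go
package main

import (
	"encoding/json"
	"errors"
)

// teeEvidence embeds the sample quote as-is under primary_evidence.
type teeEvidence struct {
	PrimaryEvidence    json.RawMessage `json:"primary_evidence"`
	AdditionalEvidence json.RawMessage `json:"additional_evidence"`
}

// wrapSampleQuote validates the attester output and embeds it without
// re-encoding, so the KBS can deserialize the quote fields directly.
func wrapSampleQuote(quote []byte) ([]byte, error) {
	if !json.Valid(quote) {
		return nil, errors.New("sample quote is not valid JSON")
	}
	return json.Marshal(teeEvidence{
		PrimaryEvidence:    json.RawMessage(quote),
		AdditionalEvidence: json.RawMessage(`{}`),
	})
}
```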

* fix: Make CC AA required for sample attestation when configured

When USE_CC_ATTESTATION_AGENT=true, attestation-service now
requires AA to be available for NoCC/sample platform. This ensures
sample evidence always comes from AA with the correct KBS format.

Changes:
- Error out if AA connection fails for NoCC platform when AA is configured
- Only use EmptyProvider if AA is explicitly NOT configured
- Prevents incorrect sample evidence format from EmptyProvider

This ensures attestation-service delegates to AA for sample evidence
generation instead of creating it itself.

* fix: Implement proper RCAR protocol with tee-pubkey and runtime-data hash

Fixed KBS attestation error 'REPORT_DATA is different from that in Sample Quote'

Changes:
1. Generate ephemeral EC key pair BEFORE getting evidence from AA
2. Create runtime-data with nonce + tee-pubkey (JWK format)
3. Hash runtime-data (SHA-256) and use as report_data for AA
4. This binds the tee-pubkey to the TEE evidence per RCAR protocol

The report_data in the evidence now matches what KBS expects:
hash(runtime-data) instead of computation ID.

This completes the full RCAR protocol implementation:
- Request → Challenge → Attestation (with bound tee-pubkey) → Response

* fix(agent): use simple nonce for Sample attestation report_data

For Sample/NoCC attestation, use the raw nonce bytes directly as
report_data instead of hashing runtime-data. This avoids JSON
serialization mismatches with the KBS Sample verifier.

Real TEEs (TDX/SNP) still use runtime-data hash binding to
cryptographically bind the ephemeral tee-pubkey to the evidence.

* fix(agent): use RFC 8785 canonical JSON for runtime-data hashing

The KBS Sample attestation verifier (and likely others) expects the
report_data to be the SHA-256 hash of the *canonical* JSON serialization
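(RFC 8785) of the runtime-data. Standard Go JSON marshaling emits struct
fields in declaration order rather than RFC 8785's sorted order (and
HTML-escapes <, > and &), so its output can differ from the canonical
form, leading to hash mismatches.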

This change uses github.com/gowebpki/jcs to canonicalize the runtime-data
before hashing, ensuring compatibility with the KBS RCAR implementation.
Also reverted the temporary 'simple nonce' workaround.

* feat(hal): add CoCo Keyprovider and Skopeo packages

- Add coco-keyprovider buildroot package with systemd service
- Add skopeo buildroot package for OCI image handling
- Add ocicrypt_keyprovider.conf for encrypted image decryption
- Update Config.in to include new packages

This enables standard CoCo ecosystem integration for encrypted
OCI images instead of custom S3/HTTP registry clients.

* feat(oci): add OCI image handling package with Skopeo integration

- Add pkg/oci/types.go with ResourceSource and ImageManifest types
- Add pkg/oci/skopeo.go with Skopeo wrapper for pull/decrypt
- Add pkg/oci/extract.go for extracting algorithms and datasets from layers

This package provides OCI image handling using Skopeo and CoCo
Keyprovider for encrypted image decryption, replacing custom
S3/HTTP registry clients.

* chore: regenerate protobuf files for updated cvms.proto

* refactor(agent): replace S3/HTTP/KBS with OCI package

- Remove pkg/kbs and pkg/registry imports
- Add pkg/oci import for OCI image handling
- Replace downloadAndDecryptResource with OCI-based implementation
- Use Skopeo + CoCo Keyprovider for automatic decryption
- Reduce code from ~240 lines to ~70 lines

This eliminates custom KBS RCAR handshake, S3/HTTP registry clients,
and manual decryption logic. CoCo Keyprovider handles all decryption
automatically via ocicrypt protocol.

* chore: remove obsolete pkg/kbs and pkg/registry packages

- Delete pkg/kbs/ (custom KBS client, ~300 lines)
- Delete pkg/registry/ (S3/HTTP registry clients, ~400 lines)
- Remove unused imports from agent/service.go
- Run go mod tidy to clean up dependencies

These packages have been replaced by pkg/oci with Skopeo and
CoCo Keyprovider for standard CoCo ecosystem integration.

* fix(agent): update ResourceSource struct to include type and encryption fields

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix(hal): update CoCo Keyprovider to v0.16.0 and fix build path

- Update version from v0.11.0 to v0.16.0 (matches attestation agent)
- Fix install path: target is at repo root, not in coco_keyprovider subdir
- This fixes the build error where coco_keyprovider binary wasn't found

The cargo workspace in guest-components builds to a shared target/
directory at the repository root, not within each crate's subdirectory.

* feat: Update remote resources testing guide to use kbs-client and coco-keyprovider for key management and encryption, enable insecure TLS for Skopeo, and enhance CVMS with

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update component versions, revise image encryption documentation, and sanitize OCI image paths for Skopeo compatibility.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add `decompress` option to Dataset and `algo_type`/`algo_args` to Algorithm protobuf messages, updating client, test, and build configurations.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Update multiple package versions and enhance OCI image extraction error reporting for missing algorithm files.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Bump package versions, improve OCI image extraction debugging by returning seen files, and remove unused dataset type parsing from test code.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Migrate OCI extraction to use structured logging with `slog` and `context`, and update package versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Bump multiple component versions, add encrypted status for computation inputs and algorithms, and refine OCI layer extraction warnings.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* logging

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add `Encrypted` field to algorithm and dataset resource sources and update all component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update component versions, integrate coco-keyprovider service, and configure ocicrypt key provider.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add support for KBS parameters and dataset/algorithm hash calculations in CVMS

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update resource download and extraction logic to support requirements.txt and improve hash verification

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Update dependencies, improve code style, and add GetRawEvidence to attestation client mocks.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor code structure for improved readability and maintainability

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: update golangci configuration to include errcheck for build path and remove unnecessary exclusions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: streamline kernel command line handling in QEMU args construction

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add attestation binary and update checksum tests and policy structure

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for attestation agent, attestation, log, crypto, OCI, and Skopeo clients

- Implement tests for the attestation agent client including Unix socket and TCP address handling, token retrieval, and error scenarios.
- Enhance attestation client tests to cover fetching raw evidence for various platforms (SNP, TDX, VTPM, SNPvTPM) and validate error handling.
- Introduce log client tests to verify retry behavior for sending logs and events.
- Create comprehensive tests for crypto package focusing on AES-GCM decryption, encrypted resource parsing, and key unwrapping.
- Add tests for OCI package to validate algorithm and dataset extraction, including JSON serialization of OCILayout.
- Implement Skopeo client tests to ensure proper functionality for image pulling, inspecting, and resource source handling.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: handle JSON marshal errors in test cases for decrypt and extract functions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add comprehensive tests for algorithm and dataset extraction with various scenarios

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: replace hardcoded Python script content with constant variable

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: remove redundant mock expectation for SendAgentConfig in TestCreateVMWithAaKbsParams

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add tests for event sending failure, dataset extraction with path traversal, and Skopeo client behavior

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add tests for download and decryption of resources with various URL formats

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Introduce OCIClient interface for agent service to improve testability of OCI image operations and enhance related tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Change `get_uint64_from_tcb` to accept `TcbVersion` by value and use `u64::from` for type conversions.

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-03-16 14:48:55 +01:00

1319 lines
39 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package atls

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"math/big"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/absmach/certs"
	certssdk "github.com/absmach/certs/sdk"
	sdkmocks "github.com/absmach/certs/sdk/mocks"
	"github.com/absmach/supermq/pkg/errors"
	"github.com/google/go-sev-guest/abi"
	"github.com/google/go-sev-guest/proto/check"
	"github.com/google/go-sev-guest/proto/sevsnp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"github.com/ultravioletrs/cocos/pkg/attestation"
	"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
	"golang.org/x/crypto/sha3"
	"google.golang.org/protobuf/encoding/protojson"
)

const (
	sevProductNameMilan = "Milan"
)

var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}

// mockAttestationClient is a simple mock for testing.
type mockAttestationClient struct {
	mock.Mock
}

func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
	args := m.Called(ctx, reportData, nonce, attType)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]byte), args.Error(1)
}

func (m *mockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
	args := m.Called(ctx, reportData, nonce, attType)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]byte), args.Error(1)
}

func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
	args := m.Called(ctx, nonce)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).([]byte), args.Error(1)
}

func (m *mockAttestationClient) Close() error {
	args := m.Called()
	return args.Error(0)
}

// generateTestCertPEM returns a self-signed test certificate with the
// common name "test", PEM-encoded with escaped newlines.
func generateTestCertPEM(t *testing.T) string {
	return generateTestCertPEMWithSubject(t, "test")
}

// generateTestCertPEMWithSubject creates a self-signed CA certificate for
// commonName and returns it PEM-encoded, with each newline replaced by the
// two-character sequence \n.
func generateTestCertPEMWithSubject(t *testing.T, commonName string) string {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: commonName,
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)

	certPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certDER,
	})

	// Escape newlines so the PEM fits in a single-line string.
	return strings.ReplaceAll(string(certPEM), "\n", "\\n")
}

// generateTestCertificateWithExtensions returns a parsed self-signed
// certificate that carries the given extra extensions.
func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: "test",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		ExtraExtensions:       extensions,
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)

	cert, err := x509.ParseCertificate(certDER)
	require.NoError(t, err)

	return cert
}

// TestDefaultCertificateSubject checks the fields of the default certificate subject.
func TestDefaultCertificateSubject(t *testing.T) {
	subject := DefaultCertificateSubject()

	assert.Equal(t, "Ultraviolet", subject.Organization)
	assert.Equal(t, "Serbia", subject.Country)
	assert.Equal(t, "", subject.Province)
	assert.Equal(t, "Belgrade", subject.Locality)
	assert.Equal(t, "Bulevar Arsenija Carnojevica 103", subject.StreetAddress)
	assert.Equal(t, "11000", subject.PostalCode)
}

// TestUnifiedCertificateGenerator tests the unified certificate generator.
func TestUnifiedCertificateGenerator(t *testing.T) {
	t.Run("SelfSignedGenerator", func(t *testing.T) {
		generator, err := NewProvider(nil, attestation.SNPvTPM, "", "", nil)
		assert.NoError(t, err)
		assert.NotNil(t, generator)
	})

	t.Run("CASignedGenerator", func(t *testing.T) {
		mockSDK := sdkmocks.NewSDK(t)
		generator, err := NewProvider(nil, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
		assert.NoError(t, err)
		assert.NotNil(t, generator)
	})
}

// TestPlatformAttestationProvider tests the platform attestation provider.
func TestPlatformAttestationProvider(t *testing.T) {
	t.Run("NewAttestationProvider", func(t *testing.T) {
		mockClient := new(mockAttestationClient)

		cases := []struct {
			name         string
			platformType attestation.PlatformType
			expectError  bool
		}{
			{"SNPvTPM", attestation.SNPvTPM, false},
			{"Azure", attestation.Azure, false},
			{"TDX", attestation.TDX, false},
			{"Invalid", attestation.PlatformType(999), true},
		}

		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				provider, err := NewAttestationProvider(mockClient, c.platformType)
				if c.expectError {
					assert.Error(t, err)
					assert.Nil(t, provider)
				} else {
					assert.NoError(t, err)
					assert.NotNil(t, provider)
					assert.Equal(t, c.platformType, provider.PlatformType())
				}
			})
		}
	})

	t.Run("GetAttestation", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		expectedAttestation := []byte("test-attestation")
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil)

		provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		pubKey := []byte("test-pubkey")
		nonce := []byte("test-nonce")

		// Use a name that does not shadow the attestation package.
		gotAttestation, err := provider.Attest(pubKey, nonce)
		assert.NoError(t, err)
		assert.Equal(t, expectedAttestation, gotAttestation)
		mockClient.AssertExpectations(t)
	})

	t.Run("GetAttestationError", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))

		provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		_, err = provider.Attest([]byte("pubkey"), []byte("nonce"))
		assert.Error(t, err)
	})
}

// TestAttestedCertificateProvider tests the attested certificate provider.
func TestAttestedCertificateProvider(t *testing.T) {
	t.Run("GetCertificateSuccess", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)

		attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		subject := DefaultCertificateSubject()
		provider := NewAttestedProvider(attestationProvider, subject)

		// Create valid client hello with nonce
		nonce := make([]byte, 64)
		_, err = rand.Read(nonce)
		require.NoError(t, err)

		serverName := hex.EncodeToString(nonce) + ".nonce"
		clientHello := &tls.ClientHelloInfo{ServerName: serverName}

		cert, err := provider.GetCertificate(clientHello)
		assert.NoError(t, err)
		assert.NotNil(t, cert)
		assert.NotEmpty(t, cert.Certificate)
		assert.NotNil(t, cert.PrivateKey)
	})

	t.Run("InvalidServerName", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
		clientHello := &tls.ClientHelloInfo{ServerName: "invalid-server-name"}

		_, err = provider.GetCertificate(clientHello)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to extract nonce")
	})

	t.Run("AttestationError", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))

		attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())

		nonce := make([]byte, 64)
		_, err = rand.Read(nonce)
		require.NoError(t, err)

		serverName := hex.EncodeToString(nonce) + ".nonce"
		clientHello := &tls.ClientHelloInfo{ServerName: serverName}

		_, err = provider.GetCertificate(clientHello)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to get attestation")
	})
}

// TestNewProvider tests the factory function.
func TestNewProvider(t *testing.T) {
	mockClient := new(mockAttestationClient)

	t.Run("SelfSignedProvider", func(t *testing.T) {
		provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
		assert.NoError(t, err)
		assert.NotNil(t, provider)
	})

	t.Run("CASignedProviderWithSDK", func(t *testing.T) {
		mockSDK := sdkmocks.NewSDK(t)
		provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
		assert.NoError(t, err)
		assert.NotNil(t, provider)
	})

	t.Run("SelfSignedProviderNilSDK", func(t *testing.T) {
		provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)
		assert.NoError(t, err)
		assert.NotNil(t, provider)
	})

	t.Run("InvalidPlatformType", func(t *testing.T) {
		_, err := NewProvider(mockClient, attestation.PlatformType(999), "", "", nil)
		assert.Error(t, err)
	})
}

// TestCertificateVerifier tests certificate verification.
func TestCertificateVerifier(t *testing.T) {
	// Set up the test attestation policy.
	tempDir, err := os.MkdirTemp("", "policy")
	require.NoError(t, err)
	defer os.RemoveAll(tempDir)

	attestationPB := prepVerifyAttReport(t)
	err = setAttestationPolicy(attestationPB, tempDir)
	require.NoError(t, err)

	t.Run("NewCertificateVerifier", func(t *testing.T) {
		rootCAs := x509.NewCertPool()
		verifier := certificateVerifier{rootCAs: rootCAs}
		assert.Equal(t, rootCAs, verifier.rootCAs)
	})

	t.Run("VerifyPeerCertificateNoCertificates", func(t *testing.T) {
		verifier := NewCertificateVerifier(nil)
		err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce"))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "no certificates provided")
	})

	t.Run("VerifyPeerCertificateInvalidCert", func(t *testing.T) {
		verifier := NewCertificateVerifier(nil)
		err := verifier.VerifyPeerCertificate([][]byte{[]byte("invalid")}, nil, []byte("nonce"))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to parse x509 certificate")
	})

	t.Run("VerifyPeerCertificateNoAttestationExtension", func(t *testing.T) {
		cert := createSelfSignedCert(t)
		verifier := NewCertificateVerifier(nil)
		err := verifier.VerifyPeerCertificate([][]byte{cert.Raw}, nil, []byte("nonce"))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "attestation extension not found")
	})
}

// TestExtractNonceFromSNI tests nonce extraction from SNI.
func TestExtractNonceFromSNI(t *testing.T) {
	t.Run("ValidNonce", func(t *testing.T) {
		nonce := make([]byte, 64)
		_, err := rand.Read(nonce)
		require.NoError(t, err)

		serverName := hex.EncodeToString(nonce) + ".nonce"
		extractedNonce, err := extractNonceFromSNI(serverName)
		assert.NoError(t, err)
		assert.Equal(t, nonce, extractedNonce)
	})

	t.Run("InvalidServerName", func(t *testing.T) {
		_, err := extractNonceFromSNI("invalid-server-name")
		assert.Error(t, err)
	})

	t.Run("InvalidNonceLength", func(t *testing.T) {
		shortNonce := make([]byte, 32) // Too short
		serverName := hex.EncodeToString(shortNonce) + ".nonce"
		_, err := extractNonceFromSNI(serverName)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid nonce length")
	})

	t.Run("InvalidHexEncoding", func(t *testing.T) {
		serverName := "invalid-hex-encoding.nonce"
		_, err := extractNonceFromSNI(serverName)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to decode nonce")
	})

	t.Run("MissingNonceSuffix", func(t *testing.T) {
		nonce := make([]byte, 64)
		_, err := rand.Read(nonce)
		require.NoError(t, err)

		serverName := hex.EncodeToString(nonce) + ".invalid"
		_, err = extractNonceFromSNI(serverName)
		assert.Error(t, err)
	})
}
// TestHasNonceSuffix tests the nonce suffix checking.
func TestHasNonceSuffix(t *testing.T) {
t.Run("ValidSuffix", func(t *testing.T) {
assert.True(t, hasNonceSuffix("test.nonce"))
})
t.Run("InvalidSuffix", func(t *testing.T) {
assert.False(t, hasNonceSuffix("test.invalid"))
})
t.Run("TooShort", func(t *testing.T) {
assert.False(t, hasNonceSuffix(".non"))
})
t.Run("EmptyString", func(t *testing.T) {
assert.False(t, hasNonceSuffix(""))
})
}
// TestPlatformVerifier tests platform verifier construction for each platform type.
func TestPlatformVerifier(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
cases := []struct {
name string
platformType attestation.PlatformType
expectedError bool
}{
{"SNPvTPM", attestation.SNPvTPM, false},
{"Azure", attestation.Azure, false},
{"TDX", attestation.TDX, true}, // Expected error due to policy format
{"Invalid", attestation.PlatformType(999), true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
verifier, err := platformVerifier(c.platformType)
if c.expectedError {
assert.Error(t, err)
assert.Nil(t, verifier)
} else {
assert.NoError(t, err)
assert.NotNil(t, verifier)
}
})
}
}
// TestGetOID tests mapping each platform type to its attestation extension OID.
func TestGetOID(t *testing.T) {
cases := []struct {
name string
platformType attestation.PlatformType
expectedOID asn1.ObjectIdentifier
expectedError bool
}{
{"SNPvTPM", attestation.SNPvTPM, SNPvTPMOID, false},
{"Azure", attestation.Azure, AzureOID, false},
{"TDX", attestation.TDX, TDXOID, false},
{"Invalid", attestation.PlatformType(999), nil, true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
oid, err := OID(c.platformType)
if c.expectedError {
assert.Error(t, err)
assert.Nil(t, oid)
} else {
assert.NoError(t, err)
assert.Equal(t, c.expectedOID, oid)
}
})
}
}
// TestPlatformTypeFromOID tests mapping attestation extension OIDs back to platform types.
func TestPlatformTypeFromOID(t *testing.T) {
cases := []struct {
name string
oid asn1.ObjectIdentifier
expectedType attestation.PlatformType
expectedError bool
}{
{"SNPvTPM", SNPvTPMOID, attestation.SNPvTPM, false},
{"Azure", AzureOID, attestation.Azure, false},
{"TDX", TDXOID, attestation.TDX, false},
{"Invalid", asn1.ObjectIdentifier{1, 2, 3}, attestation.PlatformType(0), true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
pType, err := platformTypeFromOID(c.oid)
if c.expectedError {
assert.Error(t, err)
assert.Equal(t, attestation.PlatformType(0), pType)
} else {
assert.NoError(t, err)
assert.Equal(t, c.expectedType, pType)
}
})
}
}
// TestVerifyCertificateExtension tests certificate extension verification.
func TestVerifyCertificateExtension(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
teeNonce := append(pubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
cases := []struct {
name string
extension []byte
pubKey []byte
nonce []byte
platformType attestation.PlatformType
expectError bool
}{
{
name: "ValidExtensionSNPvTPM",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true, // Expected due to invalid attestation data
},
{
name: "InvalidPlatformType",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.PlatformType(999),
expectError: true,
},
{
name: "EmptyExtension",
extension: []byte{},
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "EmptyPublicKey",
extension: hashNonce[:],
pubKey: []byte{},
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "EmptyNonce",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: []byte{},
platformType: attestation.SNPvTPM,
expectError: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
v := certificateVerifier{}
err := v.verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
if c.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
file, err := os.ReadFile("../../attestation.bin")
require.NoError(t, err)
if len(file) < abi.ReportSize {
file = append(file, make([]byte, abi.ReportSize-len(file))...)
}
rr, err := abi.ReportCertsToProto(file)
require.NoError(t, err)
return rr
}
func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error {
attestationPolicyFile, err := os.ReadFile("../../scripts/attestation_policy/sev-snp/attestation_policy.json")
if err != nil {
return err
}
unmarshalOptions := protojson.UnmarshalOptions{DiscardUnknown: true}
err = unmarshalOptions.Unmarshal(attestationPolicyFile, policy)
if err != nil {
return err
}
policy.Config.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
policy.Config.Policy.FamilyId = rr.Report.FamilyId
policy.Config.Policy.ImageId = rr.Report.ImageId
policy.Config.Policy.Measurement = rr.Report.Measurement
policy.Config.Policy.HostData = rr.Report.HostData
policy.Config.Policy.ReportIdMa = rr.Report.ReportIdMa
policy.Config.RootOfTrust.ProductLine = sevProductNameMilan
policyByte, err := vtpm.ConvertPolicyToJSON(&policy)
if err != nil {
return err
}
policyPath := filepath.Join(policyDirectory, "attestation_policy.json")
err = os.WriteFile(policyPath, policyByte, 0o644)
if err != nil {
return err
}
attestation.AttestationPolicyPath = policyPath
return nil
}
// TestCertificateVerification is a unified test suite for certificate verification.
func TestCertificateVerification(t *testing.T) {
// Set up common test data
selfSignedCert := createSelfSignedCert(t)
leafCert, rootCert := generateCertificateChain(t)
rootCAs := createCertPool(rootCert)
emptyPool := x509.NewCertPool()
t.Run("SelfSignedCertificates", func(t *testing.T) {
testCases := []testCase{
{
name: "ValidSelfSignedCertificate",
cert: selfSignedCert,
rootCAs: nil,
expectError: false,
},
{
name: "EmptyCertificate",
cert: &x509.Certificate{},
rootCAs: nil,
expectError: true,
errorMsg: "x509: missing ASN.1 contents; use ParseCertificate",
},
}
runCertificateVerificationTests(t, testCases)
})
t.Run("CertificateChainVerification", func(t *testing.T) {
testCases := []testCase{
{
name: "ValidCertificateWithRootCA",
cert: leafCert,
rootCAs: rootCAs,
expectError: false,
},
{
name: "SelfSignedCertificate",
cert: rootCert,
rootCAs: nil, // Self-signed verification
expectError: false,
},
{
name: "InvalidCertificateWithEmptyPool",
cert: rootCert,
rootCAs: emptyPool,
expectError: true,
},
}
runCertificateVerificationTests(t, testCases)
})
t.Run("ATLSPeerCertificateVerification", func(t *testing.T) {
nonce := generateNonce(t)
testCases := []atlsTestCase{
{
name: "InvalidCertificateData",
rawCerts: [][]byte{[]byte("invalid cert data")},
nonce: nonce,
rootCAs: rootCAs,
expectError: true,
errorMsg: "failed to parse x509 certificate",
},
{
name: "ValidCertificateNoAttestationExtension",
rawCerts: [][]byte{leafCert.Raw},
nonce: nonce,
rootCAs: rootCAs,
expectError: true,
errorMsg: "attestation extension not found in certificate",
},
}
runATLSVerificationTests(t, testCases)
})
}
// TestAttestedCAProvider tests the CA-signed certificate provider.
func TestAttestedCAProvider(t *testing.T) {
mockClient := new(mockAttestationClient)
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
require.NoError(t, err)
subject := DefaultCertificateSubject()
cvmID := "test-cvm-id"
agentToken := "test-token"
t.Run("NewAttestedCAProvider", func(t *testing.T) {
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
assert.NotNil(t, provider)
})
t.Run("SetTTL", func(t *testing.T) {
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
attestedProvider, ok := provider.(*attestedCertificateProvider)
require.True(t, ok, "expected *attestedCertificateProvider")
newTTL := 48 * time.Hour
attestedProvider.SetTTL(newTTL)
assert.Equal(t, newTTL, attestedProvider.ttl)
})
}
// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation.
func TestCASignedCertificateErrors(t *testing.T) {
mockClient := new(mockAttestationClient)
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
require.NoError(t, err)
subject := DefaultCertificateSubject()
cvmID := "test-cvm-id"
agentToken := "test-token"
cases := []struct {
name string
certificate string
sdkError error
expectedError string
}{
{"SDKIssueError", "", errors.NewSDKError(errors.New("SDK error")), "SDK error"},
{"InvalidPEMWithRemainingData", "-----BEGIN CERTIFICATE-----\nVGVzdA==\n-----END CERTIFICATE-----\nExtra data here", nil, "unexpected remaining data"},
{"NoPEMBlockFound", "", nil, "no PEM block found"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: c.certificate}, c.sdkError)
provider := NewAttestedCAProvider(attestationProvider, subject, mockSDK, cvmID, agentToken)
attestedProvider := provider.(*attestedCertificateProvider)
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
extension := pkix.Extension{
Id: SNPvTPMOID,
Value: []byte("test-data"),
}
_, err = attestedProvider.generateCASignedCertificate(t.Context(), privateKey, extension)
assert.Error(t, err)
assert.Contains(t, err.Error(), c.expectedError)
})
}
}
// TestGetCertificateErrors tests error paths in certificate generation.
func TestGetCertificateErrors(t *testing.T) {
t.Run("InvalidServerNameFormat", func(t *testing.T) {
mockClient := new(mockAttestationClient)
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
require.NoError(t, err)
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
clientHello := &tls.ClientHelloInfo{
ServerName: "invalid-format",
}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to extract nonce")
})
t.Run("AttestationProviderError", func(t *testing.T) {
mockClient := new(mockAttestationClient)
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
require.NoError(t, err)
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get attestation")
})
t.Run("CASignedCertificateError", func(t *testing.T) {
mockClient := new(mockAttestationClient)
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
require.NoError(t, err)
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
sdkErr := errors.NewSDKError(errors.New("CA error"))
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, sdkErr)
provider := NewAttestedCAProvider(attestationProvider, DefaultCertificateSubject(), mockSDK, "test-cvm", "test-token")
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to generate certificate")
})
}
// TestCertificateVerificationEdgeCases tests edge cases in certificate verification.
func TestCertificateVerificationEdgeCases(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) {
verifier := NewCertificateVerifier(nil)
cert1 := createSelfSignedCert(t)
cert2 := createSelfSignedCert(t)
nonce := generateNonce(t)
err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) {
cert := createSelfSignedCert(t)
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) {
wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5}
extension := pkix.Extension{
Id: wrongOID,
Value: []byte("test-data"),
}
cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension})
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) {
verifier := certificateVerifier{}
invalidPlatformType := attestation.PlatformType(999)
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
assert.Error(t, err)
// The error occurs during EAT token decoding before platform type validation
assert.Contains(t, err.Error(), "failed to decode EAT token")
})
}
// TestCertificateWithAttestationExtension tests certificates with attestation extensions.
func TestCertificateWithAttestationExtension(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("CertificateWithValidAttestationExtension", func(t *testing.T) {
// Create certificate with attestation extension
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
_, err = x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
extension := pkix.Extension{
Id: SNPvTPMOID,
Value: []byte("test-attestation-data"),
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{extension},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
verifier := certificateVerifier{}
err = verifier.verifyAttestationExtension(cert, nonce)
// Expect error due to invalid attestation data, but extension should be found
assert.Error(t, err)
assert.NotContains(t, err.Error(), "attestation extension not found")
})
}
// TestIntegrationScenarios tests end-to-end integration scenarios.
func TestIntegrationScenarios(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("FullSelfSignedFlow", func(t *testing.T) {
// Set up mock client
mockClient := new(mockAttestationClient)
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
// Create provider
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
require.NoError(t, err)
// Generate certificate
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
cert, err := provider.GetCertificate(clientHello)
assert.NoError(t, err)
assert.NotNil(t, cert)
assert.NotEmpty(t, cert.Certificate)
assert.NotNil(t, cert.PrivateKey)
// Parse the generated certificate
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
// Check for attestation extension
found := false
for _, ext := range parsedCert.Extensions {
if ext.Id.Equal(SNPvTPMOID) {
found = true
break
}
}
assert.True(t, found, "Attestation extension should be present")
})
t.Run("FullCASignedFlow", func(t *testing.T) {
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
expectedCert := certssdk.Certificate{Certificate: generateTestCertPEM(t)}
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil))
mockClient := new(mockAttestationClient)
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
cert, err := provider.GetCertificate(clientHello)
require.NoError(t, err)
require.NotNil(t, cert)
require.NotEmpty(t, cert.Certificate)
require.NotNil(t, cert.PrivateKey)
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
assert.NotNil(t, parsedCert.Subject)
mockClient.AssertExpectations(t)
mockSDK.AssertExpectations(t)
})
}
// TestConcurrentAccess tests concurrent access scenarios.
func TestConcurrentAccess(t *testing.T) {
mockClient := new(mockAttestationClient)
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
require.NoError(t, err)
const numGoroutines = 10
// Named errCh so the channel does not shadow the errors package.
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
if err != nil {
errCh <- err
return
}
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
cert, err := provider.GetCertificate(clientHello)
if err != nil {
errCh <- err
return
}
if cert == nil {
errCh <- fmt.Errorf("nil certificate returned for goroutine %d", id)
return
}
errCh <- nil
}(i)
}
// Collect one result per goroutine
for i := 0; i < numGoroutines; i++ {
err := <-errCh
assert.NoError(t, err)
}
}
// TestEdgeCasesAndBoundaries tests edge cases and boundary conditions.
func TestEdgeCasesAndBoundaries(t *testing.T) {
t.Run("LargeNonce", func(t *testing.T) {
largeNonce := make([]byte, 1024) // Much larger than the expected 64 bytes
_, err := rand.Read(largeNonce)
require.NoError(t, err)
serverName := hex.EncodeToString(largeNonce) + ".nonce"
_, err = extractNonceFromSNI(serverName)
assert.Error(t, err) // Should fail due to invalid length
})
t.Run("MaxLengthServerName", func(t *testing.T) {
// Create very long server name
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
require.NoError(t, err)
longPrefix := strings.Repeat("a", 200)
serverName := longPrefix + hex.EncodeToString(nonce) + ".nonce"
_, err = extractNonceFromSNI(serverName)
assert.Error(t, err) // Should fail due to invalid format
})
t.Run("MinimalValidNonce", func(t *testing.T) {
nonce := make([]byte, 64) // Exactly the required length
_, err := rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
extractedNonce, err := extractNonceFromSNI(serverName)
assert.NoError(t, err)
assert.Equal(t, nonce, extractedNonce)
})
}
// Unified test case structures.
type testCase struct {
name string
cert *x509.Certificate
rootCAs *x509.CertPool
expectError bool
errorMsg string
}
type atlsTestCase struct {
name string
rawCerts [][]byte
nonce []byte
rootCAs *x509.CertPool
expectError bool
errorMsg string
}
// Unified test runners.
func runCertificateVerificationTests(t *testing.T, testCases []testCase) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v := certificateVerifier{
rootCAs: tc.rootCAs,
}
err := v.verifyCertificateSignature(tc.cert)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
if tc.errorMsg == "x509: missing ASN.1 contents; use ParseCertificate" {
// Match this error with errors.Contains instead of a substring check
assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)),
fmt.Sprintf("expected error %q, got %v", tc.errorMsg, err))
} else {
assert.Contains(t, err.Error(), tc.errorMsg)
}
}
} else {
assert.NoError(t, err)
}
})
}
}
func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v := certificateVerifier{
rootCAs: tc.rootCAs,
}
err := v.VerifyPeerCertificate(tc.rawCerts, nil, tc.nonce)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
assert.Contains(t, err.Error(), tc.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
// Unified certificate creation utilities.
func createSelfSignedCert(t *testing.T) *x509.Certificate {
privateKey := generateRSAKey(t)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
return createCertificateFromTemplate(t, &template, &template, &privateKey.PublicKey, privateKey)
}
func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) {
// Generate root certificate
rootKey := generateRSAKey(t)
rootTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Root CA"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
rootCert = createCertificateFromTemplate(t, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey)
// Generate leaf certificate signed by root
leafKey := generateRSAKey(t)
leafTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Leaf"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
leafCert = createCertificateFromTemplate(t, &leafTemplate, &rootTemplate, &leafKey.PublicKey, rootKey)
return leafCert, rootCert
}
// Helper functions for consistency.
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
return privateKey
}
func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub, priv any) *x509.Certificate {
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert
}
func createCertPool(certs ...*x509.Certificate) *x509.CertPool {
pool := x509.NewCertPool()
for _, cert := range certs {
pool.AddCert(cert)
}
return pool
}
func generateNonce(t *testing.T) []byte {
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
require.NoError(t, err)
return nonce
}