NOISSUE - Azure TDX Support (#596)

* initial Azure TDX support

* add tests

* update documentation

---------

Co-authored-by: Ubuntu <danko@cocos.nbzvzgavv4yeximq0jorvcggfd.dx.internal.cloudapp.net>
This commit is contained in:
Danko Miladinovic
2026-05-25 12:22:29 +02:00
committed by GitHub
parent 27db9b29eb
commit 02aa7d7d85
11 changed files with 1302 additions and 3 deletions
+8
View File
@@ -21,12 +21,20 @@ The service is configured using the environment variables from the following tab
| AGENT_CVM_ID | Unique identifier for the CVM (Confidential Virtual Machine) | "" |
| AGENT_CERTS_TOKEN | Authentication token for certificate service access | "" |
| AGENT_MAA_URL | Microsoft Azure Attestation service URL for Azure attestation | https://sharedeus2.eus2.attest.azure.net |
| AZURE_TDX_IMDS_URL | Azure TDX quote endpoint used by direct Azure TDX attestation | http://169.254.169.254/acc/tdquote |
| AZURE_HCL_REFRESH_WAIT | Wait after writing TDX report data to Azure HCL vTPM storage before reading the refreshed HCL report | 3s |
| AGENT_OS_BUILD | Operating system build information for attestation | UVC |
| AGENT_OS_DISTRO | Operating system distribution information for attestation | UVC |
| AGENT_OS_TYPE | Operating system type information for attestation | UVC |
| ATTESTATION_SERVICE_SOCKET | Unix socket path for attestation service communication | /run/cocos/attestation.sock |
| AGENT_ENABLE_ATLS | Enable Attestation TLS for secure communication | true |
### Azure TDX Attestation
When the agent runs on an Azure TDX CVM, Azure attestation uses the direct Azure TDX flow. The agent writes TDX report data to Azure HCL vTPM storage, reads the refreshed HCL report, requests a TD quote from Azure IMDS, and submits the quote plus HCL runtime data to Microsoft Azure Attestation. This path does not depend on Confidential Containers attestation-agent `GetEvidence` or KBS token retrieval.
`AGENT_MAA_URL` selects the Microsoft Azure Attestation endpoint. `AZURE_TDX_IMDS_URL` can override the Azure IMDS TDX quote endpoint, and `AZURE_HCL_REFRESH_WAIT` controls the wait used to avoid reading a stale HCL report after report-data is written.
### Remote Resource Download (Optional)
The agent supports downloading encrypted algorithms and datasets from remote registries (S3, HTTP/HTTPS) and retrieving decryption keys from a Key Broker Service (KBS) via attestation.
+1 -1
View File
@@ -60,7 +60,7 @@ func (req azureAttestationTokenReq) validate() error {
func validateAttestationType(attType attestation.PlatformType) error {
switch attType {
case attestation.SNP, attestation.VTPM, attestation.SNPvTPM, attestation.TDX:
case attestation.SNP, attestation.VTPM, attestation.SNPvTPM, attestation.Azure, attestation.TDX:
return nil
default:
return errors.New("invalid attestation type")
@@ -33,6 +33,12 @@ func (s *service) FetchRawEvidence(ctx context.Context, req *attestationpb.Attes
var nonce [32]byte
copy(nonce[:], req.Nonce)
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
case attestationpb.PlatformType_PLATFORM_TYPE_AZURE:
var reportData [64]byte
copy(reportData[:], req.ReportData)
var nonce [32]byte
copy(nonce[:], req.Nonce)
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
case attestationpb.PlatformType_PLATFORM_TYPE_UNSPECIFIED:
// Generate sample attestation for testing in non-TEE environments
// This uses the underlying provider (EmptyProvider or CC Attestation Agent)
+7
View File
@@ -318,6 +318,13 @@ func (s *service) FetchAttestation(ctx context.Context, req *attestationpb.Attes
copy(nonce[:], req.Nonce)
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
platformType = attestation.SNPvTPM
case attestationpb.PlatformType_PLATFORM_TYPE_AZURE:
var reportData [64]byte
copy(reportData[:], req.ReportData)
var nonce [32]byte
copy(nonce[:], req.Nonce)
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
platformType = attestation.Azure
case attestationpb.PlatformType_PLATFORM_TYPE_UNSPECIFIED:
// Generate sample attestation for testing in non-TEE environments
s.logger.Warn("generating sample attestation for PLATFORM_TYPE_UNSPECIFIED - this should only be used for testing")
+25
View File
@@ -4,6 +4,10 @@
package azure
import (
"os"
"strings"
"time"
"github.com/edgelesssys/go-azguestattestation/maa"
)
@@ -28,6 +32,27 @@ func InitializeDefaultMAAVars(config *EnvConfig) {
maa.OSType = config.OSType
maa.OSDistro = config.OSDistro
MaaURL = config.MaaURL
InitializeDefaultAzureTDXVarsFromEnv()
}
func InitializeDefaultAzureTDXVars(imdsURL string, hclRefreshDelay time.Duration) {
if imdsURL = strings.TrimSpace(imdsURL); imdsURL != "" {
azureTDXIMDSQuoteURL = imdsURL
}
if hclRefreshDelay >= 0 {
azureTDXHCLRefreshDelay = hclRefreshDelay
}
}
func InitializeDefaultAzureTDXVarsFromEnv() {
imdsURL := os.Getenv("AZURE_TDX_IMDS_URL")
hclRefreshDelay := azureTDXHCLRefreshDelay
if value := strings.TrimSpace(os.Getenv("AZURE_HCL_REFRESH_WAIT")); value != "" {
if parsed, err := time.ParseDuration(value); err == nil {
hclRefreshDelay = parsed
}
}
InitializeDefaultAzureTDXVars(imdsURL, hclRefreshDelay)
}
func (c *EnvConfig) InitializeOSVars(build, osType, osDistro string) {
+32
View File
@@ -5,6 +5,7 @@ package azure
import (
"testing"
"time"
"github.com/edgelesssys/go-azguestattestation/maa"
"github.com/stretchr/testify/assert"
@@ -56,3 +57,34 @@ func TestInitializeOSVars(t *testing.T) {
assert.Equal(t, "TypeY", cfg.OSType)
assert.Equal(t, "DistroZ", cfg.OSDistro)
}
func TestInitializeDefaultAzureTDXVars(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
oldDelay := azureTDXHCLRefreshDelay
defer func() {
azureTDXIMDSQuoteURL = oldURL
azureTDXHCLRefreshDelay = oldDelay
}()
InitializeDefaultAzureTDXVars(" https://imds.example/tdquote ", 1500*time.Millisecond)
assert.Equal(t, "https://imds.example/tdquote", azureTDXIMDSQuoteURL)
assert.Equal(t, 1500*time.Millisecond, azureTDXHCLRefreshDelay)
}
func TestInitializeDefaultAzureTDXVarsFromEnv(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
oldDelay := azureTDXHCLRefreshDelay
defer func() {
azureTDXIMDSQuoteURL = oldURL
azureTDXHCLRefreshDelay = oldDelay
}()
t.Setenv("AZURE_TDX_IMDS_URL", "https://env-imds.example/tdquote")
t.Setenv("AZURE_HCL_REFRESH_WAIT", "2s")
InitializeDefaultAzureTDXVarsFromEnv()
assert.Equal(t, "https://env-imds.example/tdquote", azureTDXIMDSQuoteURL)
assert.Equal(t, 2*time.Second, azureTDXHCLRefreshDelay)
}
+24 -2
View File
@@ -52,6 +52,10 @@ func NewProvider() attestation.Provider {
}
func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) {
if isAzureTDX() {
return a.TeeAttestation(teeNonce)
}
var tokenNonce [vtpm.Nonce]byte
copy(tokenNonce[:], teeNonce)
@@ -77,6 +81,10 @@ func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
}
func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
if isAzureTDX() {
return fetchAzureTDXQuote(teeNonce)
}
var tokenNonce [vtpm.Nonce]byte
copy(tokenNonce[:], teeNonce)
@@ -89,7 +97,6 @@ func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
}
func (a provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
fmt.Printf("DEBUG: VTpmAttestation: vtpm.ExternalTPM is %T at %p\n", vtpm.ExternalTPM, &vtpm.ExternalTPM)
quote, err := vtpm.FetchQuote(vTpmNonce)
if err != nil {
return []byte{}, errors.Wrap(vtpm.ErrFetchQuote, err)
@@ -111,6 +118,14 @@ func (c *defaultMaaClient) Attest(ctx context.Context, nonce []byte, maaURL stri
var DefaultMaaClient MaaClient = &defaultMaaClient{}
func (a provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
if isAzureTDX() {
token, err := FetchAzureTDXAttestationToken(tokenNonce, MaaURL)
if err != nil {
return nil, errors.Wrap(ErrFetchAzureToken, err)
}
return token, nil
}
token, err := DefaultMaaClient.Attest(context.Background(), tokenNonce, MaaURL, http.DefaultClient)
if err != nil {
return nil, errors.Wrap(ErrFetchAzureToken, err)
@@ -152,12 +167,19 @@ func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte)
func (v verifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
attestation := &attest.Attestation{}
if err := proto.Unmarshal(report, attestation); err != nil {
return fmt.Errorf("failed to unmarshal attestation report: %w", err)
tdxErr := verifyTDXQuoteWithCoRIM(report, manifest)
if tdxErr == nil {
return nil
}
return fmt.Errorf("failed to unmarshal attestation report: %w; Azure TDX verification failed: %v", err, tdxErr)
}
// Extract measurement from SEV-SNP report if present
snpRep := attestation.GetSevSnpAttestation()
if snpRep == nil {
if tdxErr := verifyTDXQuoteWithCoRIM(report, manifest); tdxErr == nil {
return nil
}
return fmt.Errorf("no SEV-SNP attestation found in report")
}
+672
View File
@@ -0,0 +1,672 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package azure
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
tdxabi "github.com/google/go-tdx-guest/abi"
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"github.com/veraison/corim/comid"
"github.com/veraison/corim/corim"
)
const (
tdxAttestEndpoint = "attest/TdxVm"
tdxAPIVersion = "2025-06-01"
tdxRuntimeBinary = "Binary"
tdxRuntimeJSON = "JSON"
azureHCLReportNVIndex = 0x01400001
azureHCLReportDataNVIndex = 0x01400002
azureHCLSignature = "HCLA"
azureHCLVersion = 2
azureHCLRequestType = 2
azureHCLRuntimeDataVersion = 1
azureHCLHashSHA256 = 1
azureHCLReportTypeSNP = 2
azureHCLReportTypeTDX = 4
azureHCLHeaderSize = 0x20
azureHCLMaxHWReportSize = 0x4a0
azureHCLRuntimeDataOffset = azureHCLHeaderSize + azureHCLMaxHWReportSize
azureHCLRuntimeClaimsOffset = 0x14
azureTDReportSize = 0x400
azureTDReportDataOffset = 0x80
)
var (
azureTDXHCLReportReader = readAzureHCLReport
azureTDXReportDataWriter = writeAzureTDXReportData
azureTDXHCLRefreshDelay = 3 * time.Second
azureTDXIMDSQuoteURL = "http://169.254.169.254/acc/tdquote"
)
// TDXQuoteFetcher fetches a raw TDX quote for the provided REPORT_DATA.
type TDXQuoteFetcher interface {
FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error)
}
type TDXEvidenceFetcher interface {
FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error)
}
type defaultTDXQuoteFetcher struct{}
func (f defaultTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
evidence, err := f.FetchEvidence(reportData)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (f defaultTDXQuoteFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
hclReport, err := readFreshAzureTDXHCLReport(reportData)
if err != nil {
return nil, err
}
parsedReport, err := parseAzureHCLReport(hclReport)
if err != nil {
return nil, err
}
if parsedReport.reportType != azureHCLReportTypeTDX {
return nil, fmt.Errorf("Azure HCL report is not TDX")
}
if err := validateAzureTDXRuntimeClaimsHash(parsedReport); err != nil {
return nil, err
}
quote, err := DefaultAzureTDXIMDSClient.GetQuote(context.Background(), parsedReport.hwReport, http.DefaultClient)
if err != nil {
return nil, fmt.Errorf("failed to get Azure TDX quote from IMDS: %w", err)
}
return &azureTDXEvidence{
Quote: quote,
RuntimeData: append([]byte(nil), parsedReport.runtimeClaims...),
}, nil
}
// DefaultTDXQuoteFetcher is used by the Azure TDX provider and is replaceable in tests.
var DefaultTDXQuoteFetcher TDXQuoteFetcher = defaultTDXQuoteFetcher{}
// AzureTDXIMDSClient fetches an Azure TDX quote from the Azure Instance Metadata Service.
type AzureTDXIMDSClient interface {
GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error)
}
type defaultAzureTDXIMDSClient struct{}
// DefaultAzureTDXIMDSClient is used by the Azure TDX quote fetcher and is replaceable in tests.
var DefaultAzureTDXIMDSClient AzureTDXIMDSClient = &defaultAzureTDXIMDSClient{}
// AzureTDXClient submits Azure TDX VM attestation evidence to Microsoft Azure Attestation.
type AzureTDXClient interface {
AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error)
}
type defaultAzureTDXClient struct{}
// DefaultAzureTDXClient is used by Azure TDX token fetching and is replaceable in tests.
var DefaultAzureTDXClient AzureTDXClient = &defaultAzureTDXClient{}
type tdxAttestRequest struct {
Quote string `json:"quote"`
RuntimeData *tdxDataBlob `json:"runtimeData,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
type tdxDataBlob struct {
Data string `json:"data"`
DataType string `json:"dataType"`
}
type tdxAttestResponse struct {
Token string `json:"token"`
}
type tdxIMDSQuoteRequest struct {
Report string `json:"report"`
}
type tdxIMDSQuoteResponse struct {
Quote string `json:"quote"`
}
type azureTDXEvidence struct {
Quote []byte
RuntimeData []byte
}
type azureHCLReport struct {
reportType uint32
hashType uint32
hwReport []byte
runtimeClaims []byte
}
func isAzureTDX() bool {
if azureTDXHCLReportReader == nil {
return false
}
hclReport, err := azureTDXHCLReportReader()
if err != nil {
return false
}
parsedReport, err := parseAzureHCLReport(hclReport)
return err == nil && parsedReport.reportType == azureHCLReportTypeTDX
}
func fetchAzureTDXQuote(teeNonce []byte) ([]byte, error) {
if teeNonce == nil {
return nil, fmt.Errorf("tee nonce is required for Azure TDX attestation")
}
if len(teeNonce) != tdxabi.ReportDataSize {
return nil, fmt.Errorf("invalid tee nonce length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(teeNonce))
}
var reportData [tdxabi.ReportDataSize]byte
copy(reportData[:], teeNonce)
evidence, err := fetchAzureTDXEvidence(reportData, teeNonce)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (c *defaultAzureTDXClient) AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error) {
if maaURL == "" {
return "", fmt.Errorf("maaURL is empty")
}
if client == nil {
client = http.DefaultClient
}
maaURL, err := url.JoinPath(maaURL, tdxAttestEndpoint)
if err != nil {
return "", fmt.Errorf("parsing maaURL: %w", err)
}
maaURL += fmt.Sprintf("?api-version=%s", tdxAPIVersion)
attestRequest := tdxAttestRequest{
Quote: base64.RawURLEncoding.EncodeToString(quote),
}
if len(runtimeData) > 0 {
attestRequest.RuntimeData = &tdxDataBlob{
Data: base64.RawURLEncoding.EncodeToString(runtimeData),
DataType: tdxRuntimeDataType(runtimeData),
}
}
if len(nonce) > 0 {
attestRequest.Nonce = base64.RawURLEncoding.EncodeToString(nonce)
}
reqBytes, err := json.Marshal(attestRequest)
if err != nil {
return "", fmt.Errorf("marshaling TDX attestation request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, maaURL, bytes.NewReader(reqBytes))
if err != nil {
return "", fmt.Errorf("creating TDX attestation request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("doing TDX attestation request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
return "", fmt.Errorf("MAA returned %v: %s", resp.Status, msg)
}
return "", fmt.Errorf("MAA returned %v", resp.Status)
}
var attestResponse tdxAttestResponse
if err := json.NewDecoder(resp.Body).Decode(&attestResponse); err != nil {
return "", fmt.Errorf("decoding TDX attestation response: %w", err)
}
if attestResponse.Token == "" {
return "", fmt.Errorf("azure TDX attestation token not found in response")
}
return attestResponse.Token, nil
}
func (c *defaultAzureTDXIMDSClient) GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error) {
if len(tdReport) != azureTDReportSize {
return nil, fmt.Errorf("invalid TD report length: expected %d bytes, got %d bytes", azureTDReportSize, len(tdReport))
}
if client == nil {
client = http.DefaultClient
}
quoteRequest := tdxIMDSQuoteRequest{
Report: base64.RawURLEncoding.EncodeToString(tdReport),
}
reqBytes, err := json.Marshal(quoteRequest)
if err != nil {
return nil, fmt.Errorf("marshaling Azure TDX IMDS quote request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, azureTDXIMDSQuoteURL, bytes.NewReader(reqBytes))
if err != nil {
return nil, fmt.Errorf("creating Azure TDX IMDS quote request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("doing Azure TDX IMDS quote request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
return nil, fmt.Errorf("Azure TDX IMDS returned %v: %s", resp.Status, msg)
}
return nil, fmt.Errorf("Azure TDX IMDS returned %v", resp.Status)
}
var quoteResponse tdxIMDSQuoteResponse
if err := json.NewDecoder(resp.Body).Decode(&quoteResponse); err != nil {
return nil, fmt.Errorf("decoding Azure TDX IMDS quote response: %w", err)
}
if quoteResponse.Quote == "" {
return nil, fmt.Errorf("Azure TDX IMDS quote not found in response")
}
quote, err := decodeBase64URL(quoteResponse.Quote)
if err != nil {
return nil, fmt.Errorf("decoding Azure TDX IMDS quote: %w", err)
}
return quote, nil
}
// FetchAzureTDXAttestationToken fetches an Azure Attestation token for an Azure TDX VM.
func FetchAzureTDXAttestationToken(tokenNonce []byte, maaURL string) ([]byte, error) {
reportData := tdxReportDataFromRuntimeData(tokenNonce)
evidence, err := fetchAzureTDXEvidence(reportData, tokenNonce)
if err != nil {
return nil, fmt.Errorf("failed to fetch Azure TDX quote: %w", err)
}
token, err := DefaultAzureTDXClient.AttestTDXVM(context.Background(), evidence.Quote, evidence.RuntimeData, tokenNonce, maaURL, http.DefaultClient)
if err != nil {
return nil, fmt.Errorf("error fetching azure TDX token: %w", err)
}
return []byte(token), nil
}
func tdxReportDataFromRuntimeData(runtimeData []byte) [tdxabi.ReportDataSize]byte {
hash := sha256.Sum256(runtimeData)
var reportData [tdxabi.ReportDataSize]byte
copy(reportData[:sha256.Size], hash[:])
return reportData
}
func fetchAzureTDXEvidence(reportData [tdxabi.ReportDataSize]byte, fallbackRuntimeData []byte) (*azureTDXEvidence, error) {
if evidenceFetcher, ok := DefaultTDXQuoteFetcher.(TDXEvidenceFetcher); ok {
return evidenceFetcher.FetchEvidence(reportData)
}
quote, err := DefaultTDXQuoteFetcher.FetchQuote(reportData)
if err != nil {
return nil, err
}
return &azureTDXEvidence{
Quote: quote,
RuntimeData: append([]byte(nil), fallbackRuntimeData...),
}, nil
}
func readFreshAzureTDXHCLReport(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
if azureTDXReportDataWriter != nil {
if err := azureTDXReportDataWriter(reportData[:]); err != nil {
return nil, fmt.Errorf("writing Azure TDX report data: %w", err)
}
if azureTDXHCLRefreshDelay > 0 {
time.Sleep(azureTDXHCLRefreshDelay)
}
}
if azureTDXHCLReportReader == nil {
return nil, fmt.Errorf("Azure TDX HCL report reader is not configured")
}
hclReport, err := azureTDXHCLReportReader()
if err != nil {
return nil, fmt.Errorf("reading Azure TDX HCL report: %w", err)
}
return hclReport, nil
}
func readAzureHCLReport() ([]byte, error) {
tpm, err := tpm2.OpenTPM()
if err != nil {
return nil, err
}
defer tpm.Close()
return tpm2.NVReadEx(tpm, azureHCLReportNVIndex, tpm2.HandleOwner, "", 0)
}
func writeAzureTDXReportData(data []byte) error {
if len(data) != tdxabi.ReportDataSize {
return fmt.Errorf("invalid Azure TDX report data length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(data))
}
tpm, err := tpm2.OpenTPM()
if err != nil {
return err
}
defer tpm.Close()
return writeAzureTDXReportDataToTPM(tpm, data)
}
func writeAzureTDXReportDataToTPM(tpm io.ReadWriter, data []byte) error {
if len(data) > int(^uint16(0)) {
return fmt.Errorf("Azure TDX report data is too large")
}
if err := ensureAzureTDXReportDataIndex(tpm, uint16(len(data))); err != nil {
return err
}
if err := tpm2.NVWrite(tpm, tpm2.HandleOwner, azureHCLReportDataNVIndex, "", tpmutil.U16Bytes(data), 0); err != nil {
return fmt.Errorf("writing Azure TDX report-data NV index: %w", err)
}
return nil
}
func ensureAzureTDXReportDataIndex(tpm io.ReadWriter, size uint16) error {
pub, err := tpm2.NVReadPublic(tpm, azureHCLReportDataNVIndex)
if err == nil {
if pub.DataSize == size {
return nil
}
if err := tpm2.NVUndefineSpace(tpm, "", tpm2.HandleOwner, azureHCLReportDataNVIndex); err != nil {
return fmt.Errorf("undefining mismatched Azure TDX report-data NV index: %w", err)
}
}
nvPub := tpm2.NVPublic{
NVIndex: azureHCLReportDataNVIndex,
NameAlg: tpm2.AlgSHA256,
Attributes: tpm2.AttrOwnerWrite | tpm2.AttrOwnerRead,
DataSize: size,
}
authArea := tpm2.AuthCommand{
Session: tpm2.HandlePasswordSession,
Attributes: tpm2.AttrContinueSession,
Auth: []byte(""),
}
if err := tpm2.NVDefineSpaceEx(tpm, tpm2.HandleOwner, "", nvPub, authArea); err != nil {
return fmt.Errorf("defining Azure TDX report-data NV index: %w", err)
}
return nil
}
func parseAzureHCLReport(report []byte) (*azureHCLReport, error) {
minSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset
if len(report) < minSize {
return nil, fmt.Errorf("invalid Azure HCL report size: expected at least %d bytes, got %d bytes", minSize, len(report))
}
if string(report[:len(azureHCLSignature)]) != azureHCLSignature {
return nil, fmt.Errorf("invalid Azure HCL report signature")
}
if version := binary.LittleEndian.Uint32(report[4:8]); version != azureHCLVersion {
return nil, fmt.Errorf("invalid Azure HCL report version: expected %d, got %d", azureHCLVersion, version)
}
reportSize := binary.LittleEndian.Uint32(report[8:12])
if reportSize > uint32(len(report)) {
return nil, fmt.Errorf("invalid Azure HCL report size: header reports %d bytes, got %d bytes", reportSize, len(report))
}
if requestType := binary.LittleEndian.Uint32(report[12:16]); requestType != azureHCLRequestType {
return nil, fmt.Errorf("invalid Azure HCL report request type: expected %d, got %d", azureHCLRequestType, requestType)
}
runtimeData := report[azureHCLRuntimeDataOffset:]
dataSize := binary.LittleEndian.Uint32(runtimeData[0:4])
if dataSize < azureHCLRuntimeClaimsOffset {
return nil, fmt.Errorf("invalid Azure HCL runtime data size: %d", dataSize)
}
if azureHCLRuntimeDataOffset+int(dataSize) > len(report) {
return nil, fmt.Errorf("invalid Azure HCL runtime data size: header reports %d bytes", dataSize)
}
if version := binary.LittleEndian.Uint32(runtimeData[4:8]); version != azureHCLRuntimeDataVersion {
return nil, fmt.Errorf("invalid Azure HCL runtime data version: expected %d, got %d", azureHCLRuntimeDataVersion, version)
}
reportType := binary.LittleEndian.Uint32(runtimeData[8:12])
if reportType != azureHCLReportTypeSNP && reportType != azureHCLReportTypeTDX {
return nil, fmt.Errorf("invalid Azure HCL report type: %d", reportType)
}
hashType := binary.LittleEndian.Uint32(runtimeData[12:16])
claimsSize := binary.LittleEndian.Uint32(runtimeData[16:20])
claimsEnd := azureHCLRuntimeClaimsOffset + int(claimsSize)
if claimsEnd > int(dataSize) {
return nil, fmt.Errorf("invalid Azure HCL runtime claims size: %d", claimsSize)
}
hwReportSize := azureHCLMaxHWReportSize
if reportType == azureHCLReportTypeTDX {
hwReportSize = azureTDReportSize
}
if azureHCLHeaderSize+hwReportSize > len(report) {
return nil, fmt.Errorf("invalid Azure HCL hardware report size: %d", hwReportSize)
}
return &azureHCLReport{
reportType: reportType,
hashType: hashType,
hwReport: append([]byte(nil), report[azureHCLHeaderSize:azureHCLHeaderSize+hwReportSize]...),
runtimeClaims: append([]byte(nil), runtimeData[azureHCLRuntimeClaimsOffset:claimsEnd]...),
}, nil
}
func validateAzureTDXRuntimeClaimsHash(report *azureHCLReport) error {
if report.hashType != azureHCLHashSHA256 {
return fmt.Errorf("unsupported Azure HCL runtime data hash type: %d", report.hashType)
}
if len(report.hwReport) < azureTDReportDataOffset+sha256.Size {
return fmt.Errorf("invalid Azure TDX TD report size: %d", len(report.hwReport))
}
hash := sha256.Sum256(report.runtimeClaims)
if !bytes.Equal(hash[:], report.hwReport[azureTDReportDataOffset:azureTDReportDataOffset+sha256.Size]) {
return fmt.Errorf("Azure TDX runtime claims hash does not match TD report data")
}
return nil
}
func tdxRuntimeDataType(runtimeData []byte) string {
if json.Valid(runtimeData) {
return tdxRuntimeJSON
}
return tdxRuntimeBinary
}
func decodeBase64URL(value string) ([]byte, error) {
decoded, err := base64.RawURLEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
return base64.URLEncoding.DecodeString(value)
}
func verifyTDXQuoteWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
decodedQuote, err := tdxabi.QuoteToProto(report)
if err != nil {
return fmt.Errorf("failed to parse TDX quote: %w", err)
}
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
if !ok {
return fmt.Errorf("unsupported TDX quote format")
}
tdReport := quoteV4.GetTdQuoteBody()
if tdReport == nil {
return fmt.Errorf("missing TDX quote body")
}
mrtd := tdReport.GetMrTd()
if len(mrtd) == 0 {
return fmt.Errorf("no MRTD in TDX quote")
}
if err := matchMeasurementInCoRIM(manifest, mrtd); err != nil {
return fmt.Errorf("%w for Azure TDX", err)
}
return nil
}
func matchMeasurementInCoRIM(manifest *corim.UnsignedCorim, measurement []byte) error {
if manifest == nil || len(manifest.Tags) == 0 {
return fmt.Errorf("no tags in CoRIM")
}
for _, tag := range manifest.Tags {
if !bytes.HasPrefix(tag, corim.ComidTag) {
continue
}
tagValue := tag[len(corim.ComidTag):]
var c comid.Comid
if err := c.FromCBOR(tagValue); err != nil {
return fmt.Errorf("failed to parse CoMID: %w", err)
}
if c.Triples.ReferenceValues == nil {
continue
}
for _, rv := range *c.Triples.ReferenceValues {
for _, m := range rv.Measurements {
if m.Val.Digests == nil {
continue
}
for _, digest := range *m.Val.Digests {
if bytes.Equal(digest.HashValue, measurement) {
return nil
}
}
}
}
}
return fmt.Errorf("no matching reference value found in CoRIM")
}
// AzureTDXMeasurementData contains the fields extracted from an Azure TDX attestation token
// needed to construct a CoRIM policy for the TDX platform.
type AzureTDXMeasurementData struct {
MRTD string
MRSEAM string
RTMRs []string
SEAMSVN uint64
}
// ExtractAzureTDXMeasurement extracts core TDX measurements from an Azure Attestation token.
func ExtractAzureTDXMeasurement(token string) (*AzureTDXMeasurementData, error) {
claims, err := DefaultValidator.Validate(token)
if err != nil {
return nil, fmt.Errorf("failed to validate token: %w", err)
}
mrtd, ok := azureClaimString(claims, "tdx_mrtd")
if !ok {
return nil, fmt.Errorf("failed to get MRTD from claims")
}
mrSeam, _ := azureClaimString(claims, "tdx_mrseam")
rtmrs := make([]string, 0, 4)
for _, name := range []string{"tdx_rtmr0", "tdx_rtmr1", "tdx_rtmr2", "tdx_rtmr3"} {
if value, ok := azureClaimString(claims, name); ok {
rtmrs = append(rtmrs, value)
}
}
seamSVN, _ := azureClaimUint64(claims, "tdx_seamsvn")
return &AzureTDXMeasurementData{
MRTD: mrtd,
MRSEAM: mrSeam,
RTMRs: rtmrs,
SEAMSVN: seamSVN,
}, nil
}
func azureClaimString(claims map[string]any, name string) (string, bool) {
if value, ok := claims[name].(string); ok {
return value, true
}
tee, ok := claims["x-ms-isolation-tee"].(map[string]any)
if !ok {
return "", false
}
value, ok := tee[name].(string)
return value, ok
}
func azureClaimUint64(claims map[string]any, name string) (uint64, bool) {
value, ok := claims[name]
if !ok {
tee, teeOK := claims["x-ms-isolation-tee"].(map[string]any)
if !teeOK {
return 0, false
}
value, ok = tee[name]
if !ok {
return 0, false
}
}
switch typed := value.(type) {
case float64:
return uint64(typed), true
case int:
return uint64(typed), true
case uint64:
return typed, true
default:
return 0, false
}
}
+455
View File
@@ -0,0 +1,455 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package azure
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
tdxabi "github.com/google/go-tdx-guest/abi"
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
tdxtestdata "github.com/google/go-tdx-guest/testing/testdata"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/veraison/corim/comid"
"github.com/veraison/corim/corim"
"github.com/veraison/swid"
)
type mockTDXQuoteFetcher struct {
quote []byte
err error
gotReportData [tdxabi.ReportDataSize]byte
fetchQuoteCall bool
}
func (m *mockTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
m.fetchQuoteCall = true
m.gotReportData = reportData
if m.err != nil {
return nil, m.err
}
return m.quote, nil
}
type mockTDXEvidenceFetcher struct {
evidence *azureTDXEvidence
err error
gotReportData [tdxabi.ReportDataSize]byte
fetchEvidenceHit bool
}
func (m *mockTDXEvidenceFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
evidence, err := m.FetchEvidence(reportData)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (m *mockTDXEvidenceFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
m.fetchEvidenceHit = true
m.gotReportData = reportData
if m.err != nil {
return nil, m.err
}
return m.evidence, nil
}
type mockAzureTDXClient struct {
token string
err error
gotQuote []byte
gotRuntime []byte
gotNonce []byte
gotMaaURL string
attestCalls int
}
func (m *mockAzureTDXClient) AttestTDXVM(_ context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, _ *http.Client) (string, error) {
m.attestCalls++
m.gotQuote = append([]byte(nil), quote...)
m.gotRuntime = append([]byte(nil), runtimeData...)
m.gotNonce = append([]byte(nil), nonce...)
m.gotMaaURL = maaURL
if m.err != nil {
return "", m.err
}
return m.token, nil
}
type mockAzureTDXIMDSClient struct {
quote []byte
err error
gotTDReport []byte
getQuoteCall bool
}
func (m *mockAzureTDXIMDSClient) GetQuote(_ context.Context, tdReport []byte, _ *http.Client) ([]byte, error) {
m.getQuoteCall = true
m.gotTDReport = append([]byte(nil), tdReport...)
if m.err != nil {
return nil, m.err
}
return m.quote, nil
}
func testAzureHCLReport(reportType uint32, runtimeClaims []byte) []byte {
reportSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset + len(runtimeClaims)
hclReport := make([]byte, reportSize)
copy(hclReport[:len(azureHCLSignature)], azureHCLSignature)
binary.LittleEndian.PutUint32(hclReport[4:8], azureHCLVersion)
binary.LittleEndian.PutUint32(hclReport[8:12], uint32(reportSize))
binary.LittleEndian.PutUint32(hclReport[12:16], azureHCLRequestType)
runtimeData := hclReport[azureHCLRuntimeDataOffset:]
binary.LittleEndian.PutUint32(runtimeData[0:4], uint32(azureHCLRuntimeClaimsOffset+len(runtimeClaims)))
binary.LittleEndian.PutUint32(runtimeData[4:8], azureHCLRuntimeDataVersion)
binary.LittleEndian.PutUint32(runtimeData[8:12], reportType)
binary.LittleEndian.PutUint32(runtimeData[12:16], azureHCLHashSHA256)
binary.LittleEndian.PutUint32(runtimeData[16:20], uint32(len(runtimeClaims)))
copy(runtimeData[azureHCLRuntimeClaimsOffset:], runtimeClaims)
if reportType == azureHCLReportTypeTDX {
hash := sha256.Sum256(runtimeClaims)
copy(hclReport[azureHCLHeaderSize+azureTDReportDataOffset:], hash[:])
}
return hclReport
}
func TestProvider_TeeAttestation_AzureTDX(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldFetcher := DefaultTDXQuoteFetcher
defer func() {
azureTDXHCLReportReader = oldReader
DefaultTDXQuoteFetcher = oldFetcher
}()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
fetcher := &mockTDXQuoteFetcher{quote: []byte("tdx-quote")}
DefaultTDXQuoteFetcher = fetcher
reportData := bytes.Repeat([]byte{0xAB}, tdxabi.ReportDataSize)
got, err := NewProvider().TeeAttestation(reportData)
require.NoError(t, err)
assert.Equal(t, []byte("tdx-quote"), got)
assert.True(t, fetcher.fetchQuoteCall)
assert.Equal(t, reportData, fetcher.gotReportData[:])
}
func TestProvider_TeeAttestation_AzureTDX_InvalidNonce(t *testing.T) {
oldReader := azureTDXHCLReportReader
defer func() { azureTDXHCLReportReader = oldReader }()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
_, err := NewProvider().TeeAttestation([]byte("short"))
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid tee nonce length")
}
func TestProvider_AzureAttestationToken_AzureTDX(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldFetcher := DefaultTDXQuoteFetcher
oldClient := DefaultAzureTDXClient
oldMaaURL := MaaURL
defer func() {
azureTDXHCLReportReader = oldReader
DefaultTDXQuoteFetcher = oldFetcher
DefaultAzureTDXClient = oldClient
MaaURL = oldMaaURL
}()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
MaaURL = "https://tdx.example.attest.azure.net"
fetcher := &mockTDXQuoteFetcher{quote: []byte("quote")}
client := &mockAzureTDXClient{token: "tdx-token"}
DefaultTDXQuoteFetcher = fetcher
DefaultAzureTDXClient = client
nonce := []byte("token-nonce")
got, err := NewProvider().AzureAttestationToken(nonce)
require.NoError(t, err)
assert.Equal(t, []byte("tdx-token"), got)
expectedReportData := tdxReportDataFromRuntimeData(nonce)
assert.Equal(t, expectedReportData, fetcher.gotReportData)
assert.Equal(t, []byte("quote"), client.gotQuote)
assert.Equal(t, nonce, client.gotRuntime)
assert.Equal(t, nonce, client.gotNonce)
assert.Equal(t, MaaURL, client.gotMaaURL)
}
func TestFetchAzureTDXAttestationToken_UsesHCLRuntimeClaims(t *testing.T) {
oldFetcher := DefaultTDXQuoteFetcher
oldClient := DefaultAzureTDXClient
defer func() {
DefaultTDXQuoteFetcher = oldFetcher
DefaultAzureTDXClient = oldClient
}()
runtimeClaims := []byte(`{"keys":[],"user-data":"nonce"}`)
fetcher := &mockTDXEvidenceFetcher{
evidence: &azureTDXEvidence{
Quote: []byte("quote"),
RuntimeData: runtimeClaims,
},
}
client := &mockAzureTDXClient{token: "tdx-token"}
DefaultTDXQuoteFetcher = fetcher
DefaultAzureTDXClient = client
nonce := []byte("token-nonce")
got, err := FetchAzureTDXAttestationToken(nonce, "https://tdx.example.attest.azure.net")
require.NoError(t, err)
assert.Equal(t, []byte("tdx-token"), got)
assert.True(t, fetcher.fetchEvidenceHit)
assert.Equal(t, tdxReportDataFromRuntimeData(nonce), fetcher.gotReportData)
assert.Equal(t, runtimeClaims, client.gotRuntime)
assert.Equal(t, nonce, client.gotNonce)
}
func TestFetchAzureTDXAttestationToken_FetchQuoteError(t *testing.T) {
oldFetcher := DefaultTDXQuoteFetcher
defer func() { DefaultTDXQuoteFetcher = oldFetcher }()
DefaultTDXQuoteFetcher = &mockTDXQuoteFetcher{err: fmt.Errorf("quote unavailable")}
_, err := FetchAzureTDXAttestationToken([]byte("nonce"), "https://tdx.example.attest.azure.net")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch Azure TDX quote")
}
func TestDefaultTDXQuoteFetcher_FetchEvidence_AzureHCLIMDS(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldWriter := azureTDXReportDataWriter
oldDelay := azureTDXHCLRefreshDelay
oldIMDSClient := DefaultAzureTDXIMDSClient
defer func() {
azureTDXHCLReportReader = oldReader
azureTDXReportDataWriter = oldWriter
azureTDXHCLRefreshDelay = oldDelay
DefaultAzureTDXIMDSClient = oldIMDSClient
}()
runtimeClaims := []byte(`{"keys":[],"user-data":"fresh"}`)
hclReport := testAzureHCLReport(azureHCLReportTypeTDX, runtimeClaims)
var gotReportData []byte
azureTDXReportDataWriter = func(data []byte) error {
gotReportData = append([]byte(nil), data...)
return nil
}
azureTDXHCLReportReader = func() ([]byte, error) {
return hclReport, nil
}
azureTDXHCLRefreshDelay = 0
imdsClient := &mockAzureTDXIMDSClient{quote: []byte("tdx-quote")}
DefaultAzureTDXIMDSClient = imdsClient
reportData := [tdxabi.ReportDataSize]byte{0xAB}
evidence, err := defaultTDXQuoteFetcher{}.FetchEvidence(reportData)
require.NoError(t, err)
assert.Equal(t, reportData[:], gotReportData)
assert.Equal(t, []byte("tdx-quote"), evidence.Quote)
assert.Equal(t, runtimeClaims, evidence.RuntimeData)
assert.True(t, imdsClient.getQuoteCall)
assert.Len(t, imdsClient.gotTDReport, azureTDReportSize)
}
func TestDefaultAzureTDXClient_AttestTDXVM(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/attest/TdxVm", r.URL.Path)
assert.Equal(t, tdxAPIVersion, r.URL.Query().Get("api-version"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
var req tdxAttestRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("quote")), req.Quote)
require.NotNil(t, req.RuntimeData)
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("runtime")), req.RuntimeData.Data)
assert.Equal(t, tdxRuntimeBinary, req.RuntimeData.DataType)
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("nonce")), req.Nonce)
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
}))
defer server.Close()
token, err := (&defaultAzureTDXClient{}).AttestTDXVM(
context.Background(),
[]byte("quote"),
[]byte("runtime"),
[]byte("nonce"),
server.URL,
server.Client(),
)
require.NoError(t, err)
assert.Equal(t, "tdx-token", token)
}
func TestDefaultAzureTDXClient_AttestTDXVM_JSONRuntimeData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req tdxAttestRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
require.NotNil(t, req.RuntimeData)
assert.Equal(t, tdxRuntimeJSON, req.RuntimeData.DataType)
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
}))
defer server.Close()
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(
context.Background(),
[]byte("quote"),
[]byte(`{"keys":[]}`),
[]byte("nonce"),
server.URL,
server.Client(),
)
require.NoError(t, err)
}
func TestDefaultAzureTDXClient_AttestTDXVM_ErrorStatus(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad quote", http.StatusBadRequest)
}))
defer server.Close()
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(context.Background(), []byte("quote"), nil, nil, server.URL, server.Client())
require.Error(t, err)
assert.Contains(t, err.Error(), "MAA returned 400 Bad Request")
}
func TestDefaultAzureTDXIMDSClient_GetQuote(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
defer func() { azureTDXIMDSQuoteURL = oldURL }()
tdReport := bytes.Repeat([]byte{0xA5}, azureTDReportSize)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
var req tdxIMDSQuoteRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, base64.RawURLEncoding.EncodeToString(tdReport), req.Report)
resp := tdxIMDSQuoteResponse{
Quote: base64.RawURLEncoding.EncodeToString([]byte("quote")),
}
require.NoError(t, json.NewEncoder(w).Encode(resp))
}))
defer server.Close()
azureTDXIMDSQuoteURL = server.URL
quote, err := (&defaultAzureTDXIMDSClient{}).GetQuote(context.Background(), tdReport, server.Client())
require.NoError(t, err)
assert.Equal(t, []byte("quote"), quote)
}
func TestIsAzureTDX_UsesHCLReportType(t *testing.T) {
oldReader := azureTDXHCLReportReader
defer func() { azureTDXHCLReportReader = oldReader }()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeSNP, []byte("runtime")), nil
}
assert.False(t, isAzureTDX())
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
assert.True(t, isAzureTDX())
azureTDXHCLReportReader = func() ([]byte, error) {
return nil, fmt.Errorf("no vTPM")
}
assert.False(t, isAzureTDX())
}
func TestExtractAzureTDXMeasurement_Success(t *testing.T) {
oldValidator := DefaultValidator
defer func() { DefaultValidator = oldValidator }()
DefaultValidator = &mockTokenValidator{
validateFunc: func(token string) (map[string]any, error) {
return map[string]any{
"tdx_mrtd": "mrtd",
"tdx_mrseam": "mrseam",
"tdx_rtmr0": "rtmr0",
"tdx_rtmr1": "rtmr1",
"tdx_rtmr2": "rtmr2",
"tdx_rtmr3": "rtmr3",
"tdx_seamsvn": float64(7),
"unrelatedKey": "ignored",
}, nil
},
}
data, err := ExtractAzureTDXMeasurement("valid-token")
require.NoError(t, err)
assert.Equal(t, &AzureTDXMeasurementData{
MRTD: "mrtd",
MRSEAM: "mrseam",
RTMRs: []string{"rtmr0", "rtmr1", "rtmr2", "rtmr3"},
SEAMSVN: 7,
}, data)
}
func TestVerifier_VerifyWithCoRIM_AzureTDX(t *testing.T) {
decodedQuote, err := tdxabi.QuoteToProto(tdxtestdata.RawQuote)
require.NoError(t, err)
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
require.True(t, ok)
mrtd := quoteV4.GetTdQuoteBody().GetMrTd()
c := comid.NewComid()
c.SetTagIdentity("tdx-tag", 0)
m := comid.MustNewUintMeasurement(uint64(1))
m.AddDigest(swid.Sha384, mrtd)
m.SetRawValueBytes([]byte("raw"), nil)
rv := comid.ReferenceValue{
Environment: comid.Environment{
Class: comid.NewClassOID("1.2.3.4"),
},
Measurements: comid.Measurements{*m},
}
c.AddReferenceValue(rv)
manifest := corim.NewUnsignedCorim()
manifest.SetID("test-tdx-corim")
manifest.AddComid(*c)
err = NewVerifier(&bytes.Buffer{}).VerifyWithCoRIM(tdxtestdata.RawQuote, manifest)
assert.NoError(t, err)
}
+4
View File
@@ -55,6 +55,8 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
case attestation.SNPvTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
case attestation.Azure:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
default:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
@@ -92,6 +94,8 @@ func (c *client) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
case attestation.SNPvTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
case attestation.Azure:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
default:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
@@ -174,6 +174,40 @@ func TestGetAttestationTDX(t *testing.T) {
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetAttestationAzure tests getting Azure attestation.
func TestGetAttestationAzure(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-azure-evidence.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.Azure)
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
}
// TestGetAttestationVTPM tests getting vTPM attestation.
func TestGetAttestationVTPM(t *testing.T) {
tmpDir := t.TempDir()
@@ -479,6 +513,40 @@ func TestGetRawEvidenceTDX(t *testing.T) {
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetRawEvidenceAzure tests getting raw evidence for Azure platform.
func TestGetRawEvidenceAzure(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-azure.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.Azure)
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
}
// TestGetRawEvidenceVTPM tests getting raw evidence for VTPM platform.
func TestGetRawEvidenceVTPM(t *testing.T) {
tmpDir := t.TempDir()