mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Azure TDX Support (#596)
* initial Azure TDX support * add tests * update documentation --------- Co-authored-by: Ubuntu <danko@cocos.nbzvzgavv4yeximq0jorvcggfd.dx.internal.cloudapp.net>
This commit is contained in:
committed by
GitHub
parent
27db9b29eb
commit
02aa7d7d85
@@ -4,6 +4,10 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/go-azguestattestation/maa"
|
||||
)
|
||||
|
||||
@@ -28,6 +32,27 @@ func InitializeDefaultMAAVars(config *EnvConfig) {
|
||||
maa.OSType = config.OSType
|
||||
maa.OSDistro = config.OSDistro
|
||||
MaaURL = config.MaaURL
|
||||
InitializeDefaultAzureTDXVarsFromEnv()
|
||||
}
|
||||
|
||||
func InitializeDefaultAzureTDXVars(imdsURL string, hclRefreshDelay time.Duration) {
|
||||
if imdsURL = strings.TrimSpace(imdsURL); imdsURL != "" {
|
||||
azureTDXIMDSQuoteURL = imdsURL
|
||||
}
|
||||
if hclRefreshDelay >= 0 {
|
||||
azureTDXHCLRefreshDelay = hclRefreshDelay
|
||||
}
|
||||
}
|
||||
|
||||
func InitializeDefaultAzureTDXVarsFromEnv() {
|
||||
imdsURL := os.Getenv("AZURE_TDX_IMDS_URL")
|
||||
hclRefreshDelay := azureTDXHCLRefreshDelay
|
||||
if value := strings.TrimSpace(os.Getenv("AZURE_HCL_REFRESH_WAIT")); value != "" {
|
||||
if parsed, err := time.ParseDuration(value); err == nil {
|
||||
hclRefreshDelay = parsed
|
||||
}
|
||||
}
|
||||
InitializeDefaultAzureTDXVars(imdsURL, hclRefreshDelay)
|
||||
}
|
||||
|
||||
func (c *EnvConfig) InitializeOSVars(build, osType, osDistro string) {
|
||||
|
||||
@@ -5,6 +5,7 @@ package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/go-azguestattestation/maa"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -56,3 +57,34 @@ func TestInitializeOSVars(t *testing.T) {
|
||||
assert.Equal(t, "TypeY", cfg.OSType)
|
||||
assert.Equal(t, "DistroZ", cfg.OSDistro)
|
||||
}
|
||||
|
||||
func TestInitializeDefaultAzureTDXVars(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
defer func() {
|
||||
azureTDXIMDSQuoteURL = oldURL
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
}()
|
||||
|
||||
InitializeDefaultAzureTDXVars(" https://imds.example/tdquote ", 1500*time.Millisecond)
|
||||
|
||||
assert.Equal(t, "https://imds.example/tdquote", azureTDXIMDSQuoteURL)
|
||||
assert.Equal(t, 1500*time.Millisecond, azureTDXHCLRefreshDelay)
|
||||
}
|
||||
|
||||
func TestInitializeDefaultAzureTDXVarsFromEnv(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
defer func() {
|
||||
azureTDXIMDSQuoteURL = oldURL
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
}()
|
||||
|
||||
t.Setenv("AZURE_TDX_IMDS_URL", "https://env-imds.example/tdquote")
|
||||
t.Setenv("AZURE_HCL_REFRESH_WAIT", "2s")
|
||||
|
||||
InitializeDefaultAzureTDXVarsFromEnv()
|
||||
|
||||
assert.Equal(t, "https://env-imds.example/tdquote", azureTDXIMDSQuoteURL)
|
||||
assert.Equal(t, 2*time.Second, azureTDXHCLRefreshDelay)
|
||||
}
|
||||
|
||||
@@ -52,6 +52,10 @@ func NewProvider() attestation.Provider {
|
||||
}
|
||||
|
||||
func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
return a.TeeAttestation(teeNonce)
|
||||
}
|
||||
|
||||
var tokenNonce [vtpm.Nonce]byte
|
||||
copy(tokenNonce[:], teeNonce)
|
||||
|
||||
@@ -77,6 +81,10 @@ func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
return fetchAzureTDXQuote(teeNonce)
|
||||
}
|
||||
|
||||
var tokenNonce [vtpm.Nonce]byte
|
||||
copy(tokenNonce[:], teeNonce)
|
||||
|
||||
@@ -89,7 +97,6 @@ func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (a provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
|
||||
fmt.Printf("DEBUG: VTpmAttestation: vtpm.ExternalTPM is %T at %p\n", vtpm.ExternalTPM, &vtpm.ExternalTPM)
|
||||
quote, err := vtpm.FetchQuote(vTpmNonce)
|
||||
if err != nil {
|
||||
return []byte{}, errors.Wrap(vtpm.ErrFetchQuote, err)
|
||||
@@ -111,6 +118,14 @@ func (c *defaultMaaClient) Attest(ctx context.Context, nonce []byte, maaURL stri
|
||||
var DefaultMaaClient MaaClient = &defaultMaaClient{}
|
||||
|
||||
func (a provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
token, err := FetchAzureTDXAttestationToken(tokenNonce, MaaURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFetchAzureToken, err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, err := DefaultMaaClient.Attest(context.Background(), tokenNonce, MaaURL, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFetchAzureToken, err)
|
||||
@@ -152,12 +167,19 @@ func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte)
|
||||
func (v verifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
attestation := &attest.Attestation{}
|
||||
if err := proto.Unmarshal(report, attestation); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal attestation report: %w", err)
|
||||
tdxErr := verifyTDXQuoteWithCoRIM(report, manifest)
|
||||
if tdxErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to unmarshal attestation report: %w; Azure TDX verification failed: %v", err, tdxErr)
|
||||
}
|
||||
|
||||
// Extract measurement from SEV-SNP report if present
|
||||
snpRep := attestation.GetSevSnpAttestation()
|
||||
if snpRep == nil {
|
||||
if tdxErr := verifyTDXQuoteWithCoRIM(report, manifest); tdxErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("no SEV-SNP attestation found in report")
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,672 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
tdxabi "github.com/google/go-tdx-guest/abi"
|
||||
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
|
||||
"github.com/google/go-tpm/legacy/tpm2"
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
"github.com/veraison/corim/comid"
|
||||
"github.com/veraison/corim/corim"
|
||||
)
|
||||
|
||||
const (
|
||||
tdxAttestEndpoint = "attest/TdxVm"
|
||||
tdxAPIVersion = "2025-06-01"
|
||||
tdxRuntimeBinary = "Binary"
|
||||
tdxRuntimeJSON = "JSON"
|
||||
|
||||
azureHCLReportNVIndex = 0x01400001
|
||||
azureHCLReportDataNVIndex = 0x01400002
|
||||
|
||||
azureHCLSignature = "HCLA"
|
||||
azureHCLVersion = 2
|
||||
azureHCLRequestType = 2
|
||||
azureHCLRuntimeDataVersion = 1
|
||||
azureHCLHashSHA256 = 1
|
||||
azureHCLReportTypeSNP = 2
|
||||
azureHCLReportTypeTDX = 4
|
||||
|
||||
azureHCLHeaderSize = 0x20
|
||||
azureHCLMaxHWReportSize = 0x4a0
|
||||
azureHCLRuntimeDataOffset = azureHCLHeaderSize + azureHCLMaxHWReportSize
|
||||
azureHCLRuntimeClaimsOffset = 0x14
|
||||
azureTDReportSize = 0x400
|
||||
azureTDReportDataOffset = 0x80
|
||||
)
|
||||
|
||||
var (
|
||||
azureTDXHCLReportReader = readAzureHCLReport
|
||||
azureTDXReportDataWriter = writeAzureTDXReportData
|
||||
azureTDXHCLRefreshDelay = 3 * time.Second
|
||||
azureTDXIMDSQuoteURL = "http://169.254.169.254/acc/tdquote"
|
||||
)
|
||||
|
||||
// TDXQuoteFetcher fetches a raw TDX quote for the provided REPORT_DATA.
|
||||
type TDXQuoteFetcher interface {
|
||||
FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type TDXEvidenceFetcher interface {
|
||||
FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error)
|
||||
}
|
||||
|
||||
type defaultTDXQuoteFetcher struct{}
|
||||
|
||||
func (f defaultTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
evidence, err := f.FetchEvidence(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (f defaultTDXQuoteFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
|
||||
hclReport, err := readFreshAzureTDXHCLReport(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedReport, err := parseAzureHCLReport(hclReport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedReport.reportType != azureHCLReportTypeTDX {
|
||||
return nil, fmt.Errorf("Azure HCL report is not TDX")
|
||||
}
|
||||
if err := validateAzureTDXRuntimeClaimsHash(parsedReport); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
quote, err := DefaultAzureTDXIMDSClient.GetQuote(context.Background(), parsedReport.hwReport, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Azure TDX quote from IMDS: %w", err)
|
||||
}
|
||||
|
||||
return &azureTDXEvidence{
|
||||
Quote: quote,
|
||||
RuntimeData: append([]byte(nil), parsedReport.runtimeClaims...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DefaultTDXQuoteFetcher is used by the Azure TDX provider and is replaceable in tests.
|
||||
var DefaultTDXQuoteFetcher TDXQuoteFetcher = defaultTDXQuoteFetcher{}
|
||||
|
||||
// AzureTDXIMDSClient fetches an Azure TDX quote from the Azure Instance Metadata Service.
|
||||
type AzureTDXIMDSClient interface {
|
||||
GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error)
|
||||
}
|
||||
|
||||
type defaultAzureTDXIMDSClient struct{}
|
||||
|
||||
// DefaultAzureTDXIMDSClient is used by the Azure TDX quote fetcher and is replaceable in tests.
|
||||
var DefaultAzureTDXIMDSClient AzureTDXIMDSClient = &defaultAzureTDXIMDSClient{}
|
||||
|
||||
// AzureTDXClient submits Azure TDX VM attestation evidence to Microsoft Azure Attestation.
|
||||
type AzureTDXClient interface {
|
||||
AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error)
|
||||
}
|
||||
|
||||
type defaultAzureTDXClient struct{}
|
||||
|
||||
// DefaultAzureTDXClient is used by Azure TDX token fetching and is replaceable in tests.
|
||||
var DefaultAzureTDXClient AzureTDXClient = &defaultAzureTDXClient{}
|
||||
|
||||
type tdxAttestRequest struct {
|
||||
Quote string `json:"quote"`
|
||||
RuntimeData *tdxDataBlob `json:"runtimeData,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
type tdxDataBlob struct {
|
||||
Data string `json:"data"`
|
||||
DataType string `json:"dataType"`
|
||||
}
|
||||
|
||||
type tdxAttestResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type tdxIMDSQuoteRequest struct {
|
||||
Report string `json:"report"`
|
||||
}
|
||||
|
||||
type tdxIMDSQuoteResponse struct {
|
||||
Quote string `json:"quote"`
|
||||
}
|
||||
|
||||
type azureTDXEvidence struct {
|
||||
Quote []byte
|
||||
RuntimeData []byte
|
||||
}
|
||||
|
||||
type azureHCLReport struct {
|
||||
reportType uint32
|
||||
hashType uint32
|
||||
hwReport []byte
|
||||
runtimeClaims []byte
|
||||
}
|
||||
|
||||
func isAzureTDX() bool {
|
||||
if azureTDXHCLReportReader == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hclReport, err := azureTDXHCLReportReader()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
parsedReport, err := parseAzureHCLReport(hclReport)
|
||||
return err == nil && parsedReport.reportType == azureHCLReportTypeTDX
|
||||
}
|
||||
|
||||
func fetchAzureTDXQuote(teeNonce []byte) ([]byte, error) {
|
||||
if teeNonce == nil {
|
||||
return nil, fmt.Errorf("tee nonce is required for Azure TDX attestation")
|
||||
}
|
||||
if len(teeNonce) != tdxabi.ReportDataSize {
|
||||
return nil, fmt.Errorf("invalid tee nonce length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(teeNonce))
|
||||
}
|
||||
|
||||
var reportData [tdxabi.ReportDataSize]byte
|
||||
copy(reportData[:], teeNonce)
|
||||
|
||||
evidence, err := fetchAzureTDXEvidence(reportData, teeNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (c *defaultAzureTDXClient) AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error) {
|
||||
if maaURL == "" {
|
||||
return "", fmt.Errorf("maaURL is empty")
|
||||
}
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
maaURL, err := url.JoinPath(maaURL, tdxAttestEndpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing maaURL: %w", err)
|
||||
}
|
||||
maaURL += fmt.Sprintf("?api-version=%s", tdxAPIVersion)
|
||||
|
||||
attestRequest := tdxAttestRequest{
|
||||
Quote: base64.RawURLEncoding.EncodeToString(quote),
|
||||
}
|
||||
if len(runtimeData) > 0 {
|
||||
attestRequest.RuntimeData = &tdxDataBlob{
|
||||
Data: base64.RawURLEncoding.EncodeToString(runtimeData),
|
||||
DataType: tdxRuntimeDataType(runtimeData),
|
||||
}
|
||||
}
|
||||
if len(nonce) > 0 {
|
||||
attestRequest.Nonce = base64.RawURLEncoding.EncodeToString(nonce)
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(attestRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling TDX attestation request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, maaURL, bytes.NewReader(reqBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating TDX attestation request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("doing TDX attestation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
|
||||
return "", fmt.Errorf("MAA returned %v: %s", resp.Status, msg)
|
||||
}
|
||||
return "", fmt.Errorf("MAA returned %v", resp.Status)
|
||||
}
|
||||
|
||||
var attestResponse tdxAttestResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&attestResponse); err != nil {
|
||||
return "", fmt.Errorf("decoding TDX attestation response: %w", err)
|
||||
}
|
||||
if attestResponse.Token == "" {
|
||||
return "", fmt.Errorf("azure TDX attestation token not found in response")
|
||||
}
|
||||
|
||||
return attestResponse.Token, nil
|
||||
}
|
||||
|
||||
func (c *defaultAzureTDXIMDSClient) GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error) {
|
||||
if len(tdReport) != azureTDReportSize {
|
||||
return nil, fmt.Errorf("invalid TD report length: expected %d bytes, got %d bytes", azureTDReportSize, len(tdReport))
|
||||
}
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
quoteRequest := tdxIMDSQuoteRequest{
|
||||
Report: base64.RawURLEncoding.EncodeToString(tdReport),
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(quoteRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, azureTDXIMDSQuoteURL, bytes.NewReader(reqBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("doing Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
|
||||
return nil, fmt.Errorf("Azure TDX IMDS returned %v: %s", resp.Status, msg)
|
||||
}
|
||||
return nil, fmt.Errorf("Azure TDX IMDS returned %v", resp.Status)
|
||||
}
|
||||
|
||||
var quoteResponse tdxIMDSQuoteResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode("eResponse); err != nil {
|
||||
return nil, fmt.Errorf("decoding Azure TDX IMDS quote response: %w", err)
|
||||
}
|
||||
if quoteResponse.Quote == "" {
|
||||
return nil, fmt.Errorf("Azure TDX IMDS quote not found in response")
|
||||
}
|
||||
|
||||
quote, err := decodeBase64URL(quoteResponse.Quote)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding Azure TDX IMDS quote: %w", err)
|
||||
}
|
||||
|
||||
return quote, nil
|
||||
}
|
||||
|
||||
// FetchAzureTDXAttestationToken fetches an Azure Attestation token for an Azure TDX VM.
|
||||
func FetchAzureTDXAttestationToken(tokenNonce []byte, maaURL string) ([]byte, error) {
|
||||
reportData := tdxReportDataFromRuntimeData(tokenNonce)
|
||||
evidence, err := fetchAzureTDXEvidence(reportData, tokenNonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Azure TDX quote: %w", err)
|
||||
}
|
||||
|
||||
token, err := DefaultAzureTDXClient.AttestTDXVM(context.Background(), evidence.Quote, evidence.RuntimeData, tokenNonce, maaURL, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching azure TDX token: %w", err)
|
||||
}
|
||||
|
||||
return []byte(token), nil
|
||||
}
|
||||
|
||||
func tdxReportDataFromRuntimeData(runtimeData []byte) [tdxabi.ReportDataSize]byte {
|
||||
hash := sha256.Sum256(runtimeData)
|
||||
var reportData [tdxabi.ReportDataSize]byte
|
||||
copy(reportData[:sha256.Size], hash[:])
|
||||
return reportData
|
||||
}
|
||||
|
||||
func fetchAzureTDXEvidence(reportData [tdxabi.ReportDataSize]byte, fallbackRuntimeData []byte) (*azureTDXEvidence, error) {
|
||||
if evidenceFetcher, ok := DefaultTDXQuoteFetcher.(TDXEvidenceFetcher); ok {
|
||||
return evidenceFetcher.FetchEvidence(reportData)
|
||||
}
|
||||
|
||||
quote, err := DefaultTDXQuoteFetcher.FetchQuote(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &azureTDXEvidence{
|
||||
Quote: quote,
|
||||
RuntimeData: append([]byte(nil), fallbackRuntimeData...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readFreshAzureTDXHCLReport(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
if azureTDXReportDataWriter != nil {
|
||||
if err := azureTDXReportDataWriter(reportData[:]); err != nil {
|
||||
return nil, fmt.Errorf("writing Azure TDX report data: %w", err)
|
||||
}
|
||||
if azureTDXHCLRefreshDelay > 0 {
|
||||
time.Sleep(azureTDXHCLRefreshDelay)
|
||||
}
|
||||
}
|
||||
if azureTDXHCLReportReader == nil {
|
||||
return nil, fmt.Errorf("Azure TDX HCL report reader is not configured")
|
||||
}
|
||||
|
||||
hclReport, err := azureTDXHCLReportReader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading Azure TDX HCL report: %w", err)
|
||||
}
|
||||
|
||||
return hclReport, nil
|
||||
}
|
||||
|
||||
func readAzureHCLReport() ([]byte, error) {
|
||||
tpm, err := tpm2.OpenTPM()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return tpm2.NVReadEx(tpm, azureHCLReportNVIndex, tpm2.HandleOwner, "", 0)
|
||||
}
|
||||
|
||||
func writeAzureTDXReportData(data []byte) error {
|
||||
if len(data) != tdxabi.ReportDataSize {
|
||||
return fmt.Errorf("invalid Azure TDX report data length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(data))
|
||||
}
|
||||
|
||||
tpm, err := tpm2.OpenTPM()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return writeAzureTDXReportDataToTPM(tpm, data)
|
||||
}
|
||||
|
||||
func writeAzureTDXReportDataToTPM(tpm io.ReadWriter, data []byte) error {
|
||||
if len(data) > int(^uint16(0)) {
|
||||
return fmt.Errorf("Azure TDX report data is too large")
|
||||
}
|
||||
|
||||
if err := ensureAzureTDXReportDataIndex(tpm, uint16(len(data))); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tpm2.NVWrite(tpm, tpm2.HandleOwner, azureHCLReportDataNVIndex, "", tpmutil.U16Bytes(data), 0); err != nil {
|
||||
return fmt.Errorf("writing Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAzureTDXReportDataIndex(tpm io.ReadWriter, size uint16) error {
|
||||
pub, err := tpm2.NVReadPublic(tpm, azureHCLReportDataNVIndex)
|
||||
if err == nil {
|
||||
if pub.DataSize == size {
|
||||
return nil
|
||||
}
|
||||
if err := tpm2.NVUndefineSpace(tpm, "", tpm2.HandleOwner, azureHCLReportDataNVIndex); err != nil {
|
||||
return fmt.Errorf("undefining mismatched Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nvPub := tpm2.NVPublic{
|
||||
NVIndex: azureHCLReportDataNVIndex,
|
||||
NameAlg: tpm2.AlgSHA256,
|
||||
Attributes: tpm2.AttrOwnerWrite | tpm2.AttrOwnerRead,
|
||||
DataSize: size,
|
||||
}
|
||||
authArea := tpm2.AuthCommand{
|
||||
Session: tpm2.HandlePasswordSession,
|
||||
Attributes: tpm2.AttrContinueSession,
|
||||
Auth: []byte(""),
|
||||
}
|
||||
if err := tpm2.NVDefineSpaceEx(tpm, tpm2.HandleOwner, "", nvPub, authArea); err != nil {
|
||||
return fmt.Errorf("defining Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAzureHCLReport(report []byte) (*azureHCLReport, error) {
|
||||
minSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset
|
||||
if len(report) < minSize {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report size: expected at least %d bytes, got %d bytes", minSize, len(report))
|
||||
}
|
||||
if string(report[:len(azureHCLSignature)]) != azureHCLSignature {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report signature")
|
||||
}
|
||||
if version := binary.LittleEndian.Uint32(report[4:8]); version != azureHCLVersion {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report version: expected %d, got %d", azureHCLVersion, version)
|
||||
}
|
||||
reportSize := binary.LittleEndian.Uint32(report[8:12])
|
||||
if reportSize > uint32(len(report)) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report size: header reports %d bytes, got %d bytes", reportSize, len(report))
|
||||
}
|
||||
if requestType := binary.LittleEndian.Uint32(report[12:16]); requestType != azureHCLRequestType {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report request type: expected %d, got %d", azureHCLRequestType, requestType)
|
||||
}
|
||||
|
||||
runtimeData := report[azureHCLRuntimeDataOffset:]
|
||||
dataSize := binary.LittleEndian.Uint32(runtimeData[0:4])
|
||||
if dataSize < azureHCLRuntimeClaimsOffset {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data size: %d", dataSize)
|
||||
}
|
||||
if azureHCLRuntimeDataOffset+int(dataSize) > len(report) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data size: header reports %d bytes", dataSize)
|
||||
}
|
||||
if version := binary.LittleEndian.Uint32(runtimeData[4:8]); version != azureHCLRuntimeDataVersion {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data version: expected %d, got %d", azureHCLRuntimeDataVersion, version)
|
||||
}
|
||||
|
||||
reportType := binary.LittleEndian.Uint32(runtimeData[8:12])
|
||||
if reportType != azureHCLReportTypeSNP && reportType != azureHCLReportTypeTDX {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report type: %d", reportType)
|
||||
}
|
||||
|
||||
hashType := binary.LittleEndian.Uint32(runtimeData[12:16])
|
||||
claimsSize := binary.LittleEndian.Uint32(runtimeData[16:20])
|
||||
claimsEnd := azureHCLRuntimeClaimsOffset + int(claimsSize)
|
||||
if claimsEnd > int(dataSize) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime claims size: %d", claimsSize)
|
||||
}
|
||||
|
||||
hwReportSize := azureHCLMaxHWReportSize
|
||||
if reportType == azureHCLReportTypeTDX {
|
||||
hwReportSize = azureTDReportSize
|
||||
}
|
||||
if azureHCLHeaderSize+hwReportSize > len(report) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL hardware report size: %d", hwReportSize)
|
||||
}
|
||||
|
||||
return &azureHCLReport{
|
||||
reportType: reportType,
|
||||
hashType: hashType,
|
||||
hwReport: append([]byte(nil), report[azureHCLHeaderSize:azureHCLHeaderSize+hwReportSize]...),
|
||||
runtimeClaims: append([]byte(nil), runtimeData[azureHCLRuntimeClaimsOffset:claimsEnd]...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateAzureTDXRuntimeClaimsHash(report *azureHCLReport) error {
|
||||
if report.hashType != azureHCLHashSHA256 {
|
||||
return fmt.Errorf("unsupported Azure HCL runtime data hash type: %d", report.hashType)
|
||||
}
|
||||
if len(report.hwReport) < azureTDReportDataOffset+sha256.Size {
|
||||
return fmt.Errorf("invalid Azure TDX TD report size: %d", len(report.hwReport))
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(report.runtimeClaims)
|
||||
if !bytes.Equal(hash[:], report.hwReport[azureTDReportDataOffset:azureTDReportDataOffset+sha256.Size]) {
|
||||
return fmt.Errorf("Azure TDX runtime claims hash does not match TD report data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tdxRuntimeDataType(runtimeData []byte) string {
|
||||
if json.Valid(runtimeData) {
|
||||
return tdxRuntimeJSON
|
||||
}
|
||||
return tdxRuntimeBinary
|
||||
}
|
||||
|
||||
func decodeBase64URL(value string) ([]byte, error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
return base64.URLEncoding.DecodeString(value)
|
||||
}
|
||||
|
||||
func verifyTDXQuoteWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
decodedQuote, err := tdxabi.QuoteToProto(report)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TDX quote: %w", err)
|
||||
}
|
||||
|
||||
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported TDX quote format")
|
||||
}
|
||||
|
||||
tdReport := quoteV4.GetTdQuoteBody()
|
||||
if tdReport == nil {
|
||||
return fmt.Errorf("missing TDX quote body")
|
||||
}
|
||||
|
||||
mrtd := tdReport.GetMrTd()
|
||||
if len(mrtd) == 0 {
|
||||
return fmt.Errorf("no MRTD in TDX quote")
|
||||
}
|
||||
|
||||
if err := matchMeasurementInCoRIM(manifest, mrtd); err != nil {
|
||||
return fmt.Errorf("%w for Azure TDX", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func matchMeasurementInCoRIM(manifest *corim.UnsignedCorim, measurement []byte) error {
|
||||
if manifest == nil || len(manifest.Tags) == 0 {
|
||||
return fmt.Errorf("no tags in CoRIM")
|
||||
}
|
||||
|
||||
for _, tag := range manifest.Tags {
|
||||
if !bytes.HasPrefix(tag, corim.ComidTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagValue := tag[len(corim.ComidTag):]
|
||||
|
||||
var c comid.Comid
|
||||
if err := c.FromCBOR(tagValue); err != nil {
|
||||
return fmt.Errorf("failed to parse CoMID: %w", err)
|
||||
}
|
||||
|
||||
if c.Triples.ReferenceValues == nil {
|
||||
continue
|
||||
}
|
||||
for _, rv := range *c.Triples.ReferenceValues {
|
||||
for _, m := range rv.Measurements {
|
||||
if m.Val.Digests == nil {
|
||||
continue
|
||||
}
|
||||
for _, digest := range *m.Val.Digests {
|
||||
if bytes.Equal(digest.HashValue, measurement) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("no matching reference value found in CoRIM")
|
||||
}
|
||||
|
||||
// AzureTDXMeasurementData contains the fields extracted from an Azure TDX attestation token
|
||||
// needed to construct a CoRIM policy for the TDX platform.
|
||||
type AzureTDXMeasurementData struct {
|
||||
MRTD string
|
||||
MRSEAM string
|
||||
RTMRs []string
|
||||
SEAMSVN uint64
|
||||
}
|
||||
|
||||
// ExtractAzureTDXMeasurement extracts core TDX measurements from an Azure Attestation token.
|
||||
func ExtractAzureTDXMeasurement(token string) (*AzureTDXMeasurementData, error) {
|
||||
claims, err := DefaultValidator.Validate(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate token: %w", err)
|
||||
}
|
||||
|
||||
mrtd, ok := azureClaimString(claims, "tdx_mrtd")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get MRTD from claims")
|
||||
}
|
||||
|
||||
mrSeam, _ := azureClaimString(claims, "tdx_mrseam")
|
||||
|
||||
rtmrs := make([]string, 0, 4)
|
||||
for _, name := range []string{"tdx_rtmr0", "tdx_rtmr1", "tdx_rtmr2", "tdx_rtmr3"} {
|
||||
if value, ok := azureClaimString(claims, name); ok {
|
||||
rtmrs = append(rtmrs, value)
|
||||
}
|
||||
}
|
||||
|
||||
seamSVN, _ := azureClaimUint64(claims, "tdx_seamsvn")
|
||||
|
||||
return &AzureTDXMeasurementData{
|
||||
MRTD: mrtd,
|
||||
MRSEAM: mrSeam,
|
||||
RTMRs: rtmrs,
|
||||
SEAMSVN: seamSVN,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func azureClaimString(claims map[string]any, name string) (string, bool) {
|
||||
if value, ok := claims[name].(string); ok {
|
||||
return value, true
|
||||
}
|
||||
|
||||
tee, ok := claims["x-ms-isolation-tee"].(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
value, ok := tee[name].(string)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func azureClaimUint64(claims map[string]any, name string) (uint64, bool) {
|
||||
value, ok := claims[name]
|
||||
if !ok {
|
||||
tee, teeOK := claims["x-ms-isolation-tee"].(map[string]any)
|
||||
if !teeOK {
|
||||
return 0, false
|
||||
}
|
||||
value, ok = tee[name]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return uint64(typed), true
|
||||
case int:
|
||||
return uint64(typed), true
|
||||
case uint64:
|
||||
return typed, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,455 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
tdxabi "github.com/google/go-tdx-guest/abi"
|
||||
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
|
||||
tdxtestdata "github.com/google/go-tdx-guest/testing/testdata"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/veraison/corim/comid"
|
||||
"github.com/veraison/corim/corim"
|
||||
"github.com/veraison/swid"
|
||||
)
|
||||
|
||||
type mockTDXQuoteFetcher struct {
|
||||
quote []byte
|
||||
err error
|
||||
gotReportData [tdxabi.ReportDataSize]byte
|
||||
fetchQuoteCall bool
|
||||
}
|
||||
|
||||
func (m *mockTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
m.fetchQuoteCall = true
|
||||
m.gotReportData = reportData
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.quote, nil
|
||||
}
|
||||
|
||||
type mockTDXEvidenceFetcher struct {
|
||||
evidence *azureTDXEvidence
|
||||
err error
|
||||
gotReportData [tdxabi.ReportDataSize]byte
|
||||
fetchEvidenceHit bool
|
||||
}
|
||||
|
||||
func (m *mockTDXEvidenceFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
evidence, err := m.FetchEvidence(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (m *mockTDXEvidenceFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
|
||||
m.fetchEvidenceHit = true
|
||||
m.gotReportData = reportData
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.evidence, nil
|
||||
}
|
||||
|
||||
type mockAzureTDXClient struct {
|
||||
token string
|
||||
err error
|
||||
gotQuote []byte
|
||||
gotRuntime []byte
|
||||
gotNonce []byte
|
||||
gotMaaURL string
|
||||
attestCalls int
|
||||
}
|
||||
|
||||
func (m *mockAzureTDXClient) AttestTDXVM(_ context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, _ *http.Client) (string, error) {
|
||||
m.attestCalls++
|
||||
m.gotQuote = append([]byte(nil), quote...)
|
||||
m.gotRuntime = append([]byte(nil), runtimeData...)
|
||||
m.gotNonce = append([]byte(nil), nonce...)
|
||||
m.gotMaaURL = maaURL
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.token, nil
|
||||
}
|
||||
|
||||
type mockAzureTDXIMDSClient struct {
|
||||
quote []byte
|
||||
err error
|
||||
gotTDReport []byte
|
||||
getQuoteCall bool
|
||||
}
|
||||
|
||||
func (m *mockAzureTDXIMDSClient) GetQuote(_ context.Context, tdReport []byte, _ *http.Client) ([]byte, error) {
|
||||
m.getQuoteCall = true
|
||||
m.gotTDReport = append([]byte(nil), tdReport...)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.quote, nil
|
||||
}
|
||||
|
||||
func testAzureHCLReport(reportType uint32, runtimeClaims []byte) []byte {
|
||||
reportSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset + len(runtimeClaims)
|
||||
hclReport := make([]byte, reportSize)
|
||||
copy(hclReport[:len(azureHCLSignature)], azureHCLSignature)
|
||||
binary.LittleEndian.PutUint32(hclReport[4:8], azureHCLVersion)
|
||||
binary.LittleEndian.PutUint32(hclReport[8:12], uint32(reportSize))
|
||||
binary.LittleEndian.PutUint32(hclReport[12:16], azureHCLRequestType)
|
||||
|
||||
runtimeData := hclReport[azureHCLRuntimeDataOffset:]
|
||||
binary.LittleEndian.PutUint32(runtimeData[0:4], uint32(azureHCLRuntimeClaimsOffset+len(runtimeClaims)))
|
||||
binary.LittleEndian.PutUint32(runtimeData[4:8], azureHCLRuntimeDataVersion)
|
||||
binary.LittleEndian.PutUint32(runtimeData[8:12], reportType)
|
||||
binary.LittleEndian.PutUint32(runtimeData[12:16], azureHCLHashSHA256)
|
||||
binary.LittleEndian.PutUint32(runtimeData[16:20], uint32(len(runtimeClaims)))
|
||||
copy(runtimeData[azureHCLRuntimeClaimsOffset:], runtimeClaims)
|
||||
|
||||
if reportType == azureHCLReportTypeTDX {
|
||||
hash := sha256.Sum256(runtimeClaims)
|
||||
copy(hclReport[azureHCLHeaderSize+azureTDReportDataOffset:], hash[:])
|
||||
}
|
||||
|
||||
return hclReport
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation_AzureTDX(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
}()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
fetcher := &mockTDXQuoteFetcher{quote: []byte("tdx-quote")}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
|
||||
reportData := bytes.Repeat([]byte{0xAB}, tdxabi.ReportDataSize)
|
||||
got, err := NewProvider().TeeAttestation(reportData)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-quote"), got)
|
||||
assert.True(t, fetcher.fetchQuoteCall)
|
||||
assert.Equal(t, reportData, fetcher.gotReportData[:])
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation_AzureTDX_InvalidNonce(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
defer func() { azureTDXHCLReportReader = oldReader }()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
|
||||
_, err := NewProvider().TeeAttestation([]byte("short"))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid tee nonce length")
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken_AzureTDX(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
oldClient := DefaultAzureTDXClient
|
||||
oldMaaURL := MaaURL
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
DefaultAzureTDXClient = oldClient
|
||||
MaaURL = oldMaaURL
|
||||
}()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
MaaURL = "https://tdx.example.attest.azure.net"
|
||||
|
||||
fetcher := &mockTDXQuoteFetcher{quote: []byte("quote")}
|
||||
client := &mockAzureTDXClient{token: "tdx-token"}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
DefaultAzureTDXClient = client
|
||||
|
||||
nonce := []byte("token-nonce")
|
||||
got, err := NewProvider().AzureAttestationToken(nonce)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-token"), got)
|
||||
|
||||
expectedReportData := tdxReportDataFromRuntimeData(nonce)
|
||||
assert.Equal(t, expectedReportData, fetcher.gotReportData)
|
||||
assert.Equal(t, []byte("quote"), client.gotQuote)
|
||||
assert.Equal(t, nonce, client.gotRuntime)
|
||||
assert.Equal(t, nonce, client.gotNonce)
|
||||
assert.Equal(t, MaaURL, client.gotMaaURL)
|
||||
}
|
||||
|
||||
func TestFetchAzureTDXAttestationToken_UsesHCLRuntimeClaims(t *testing.T) {
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
oldClient := DefaultAzureTDXClient
|
||||
defer func() {
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
DefaultAzureTDXClient = oldClient
|
||||
}()
|
||||
|
||||
runtimeClaims := []byte(`{"keys":[],"user-data":"nonce"}`)
|
||||
fetcher := &mockTDXEvidenceFetcher{
|
||||
evidence: &azureTDXEvidence{
|
||||
Quote: []byte("quote"),
|
||||
RuntimeData: runtimeClaims,
|
||||
},
|
||||
}
|
||||
client := &mockAzureTDXClient{token: "tdx-token"}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
DefaultAzureTDXClient = client
|
||||
|
||||
nonce := []byte("token-nonce")
|
||||
got, err := FetchAzureTDXAttestationToken(nonce, "https://tdx.example.attest.azure.net")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-token"), got)
|
||||
assert.True(t, fetcher.fetchEvidenceHit)
|
||||
assert.Equal(t, tdxReportDataFromRuntimeData(nonce), fetcher.gotReportData)
|
||||
assert.Equal(t, runtimeClaims, client.gotRuntime)
|
||||
assert.Equal(t, nonce, client.gotNonce)
|
||||
}
|
||||
|
||||
func TestFetchAzureTDXAttestationToken_FetchQuoteError(t *testing.T) {
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
defer func() { DefaultTDXQuoteFetcher = oldFetcher }()
|
||||
|
||||
DefaultTDXQuoteFetcher = &mockTDXQuoteFetcher{err: fmt.Errorf("quote unavailable")}
|
||||
|
||||
_, err := FetchAzureTDXAttestationToken([]byte("nonce"), "https://tdx.example.attest.azure.net")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to fetch Azure TDX quote")
|
||||
}
|
||||
|
||||
func TestDefaultTDXQuoteFetcher_FetchEvidence_AzureHCLIMDS(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldWriter := azureTDXReportDataWriter
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
oldIMDSClient := DefaultAzureTDXIMDSClient
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
azureTDXReportDataWriter = oldWriter
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
DefaultAzureTDXIMDSClient = oldIMDSClient
|
||||
}()
|
||||
|
||||
runtimeClaims := []byte(`{"keys":[],"user-data":"fresh"}`)
|
||||
hclReport := testAzureHCLReport(azureHCLReportTypeTDX, runtimeClaims)
|
||||
var gotReportData []byte
|
||||
azureTDXReportDataWriter = func(data []byte) error {
|
||||
gotReportData = append([]byte(nil), data...)
|
||||
return nil
|
||||
}
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return hclReport, nil
|
||||
}
|
||||
azureTDXHCLRefreshDelay = 0
|
||||
imdsClient := &mockAzureTDXIMDSClient{quote: []byte("tdx-quote")}
|
||||
DefaultAzureTDXIMDSClient = imdsClient
|
||||
|
||||
reportData := [tdxabi.ReportDataSize]byte{0xAB}
|
||||
evidence, err := defaultTDXQuoteFetcher{}.FetchEvidence(reportData)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, reportData[:], gotReportData)
|
||||
assert.Equal(t, []byte("tdx-quote"), evidence.Quote)
|
||||
assert.Equal(t, runtimeClaims, evidence.RuntimeData)
|
||||
assert.True(t, imdsClient.getQuoteCall)
|
||||
assert.Len(t, imdsClient.gotTDReport, azureTDReportSize)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "/attest/TdxVm", r.URL.Path)
|
||||
assert.Equal(t, tdxAPIVersion, r.URL.Query().Get("api-version"))
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
var req tdxAttestRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("quote")), req.Quote)
|
||||
require.NotNil(t, req.RuntimeData)
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("runtime")), req.RuntimeData.Data)
|
||||
assert.Equal(t, tdxRuntimeBinary, req.RuntimeData.DataType)
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("nonce")), req.Nonce)
|
||||
|
||||
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
token, err := (&defaultAzureTDXClient{}).AttestTDXVM(
|
||||
context.Background(),
|
||||
[]byte("quote"),
|
||||
[]byte("runtime"),
|
||||
[]byte("nonce"),
|
||||
server.URL,
|
||||
server.Client(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "tdx-token", token)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM_JSONRuntimeData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req tdxAttestRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
require.NotNil(t, req.RuntimeData)
|
||||
assert.Equal(t, tdxRuntimeJSON, req.RuntimeData.DataType)
|
||||
|
||||
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(
|
||||
context.Background(),
|
||||
[]byte("quote"),
|
||||
[]byte(`{"keys":[]}`),
|
||||
[]byte("nonce"),
|
||||
server.URL,
|
||||
server.Client(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM_ErrorStatus(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad quote", http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(context.Background(), []byte("quote"), nil, nil, server.URL, server.Client())
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "MAA returned 400 Bad Request")
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXIMDSClient_GetQuote(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
defer func() { azureTDXIMDSQuoteURL = oldURL }()
|
||||
|
||||
tdReport := bytes.Repeat([]byte{0xA5}, azureTDReportSize)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
var req tdxIMDSQuoteRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString(tdReport), req.Report)
|
||||
|
||||
resp := tdxIMDSQuoteResponse{
|
||||
Quote: base64.RawURLEncoding.EncodeToString([]byte("quote")),
|
||||
}
|
||||
require.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
}))
|
||||
defer server.Close()
|
||||
azureTDXIMDSQuoteURL = server.URL
|
||||
|
||||
quote, err := (&defaultAzureTDXIMDSClient{}).GetQuote(context.Background(), tdReport, server.Client())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("quote"), quote)
|
||||
}
|
||||
|
||||
func TestIsAzureTDX_UsesHCLReportType(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
defer func() { azureTDXHCLReportReader = oldReader }()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeSNP, []byte("runtime")), nil
|
||||
}
|
||||
assert.False(t, isAzureTDX())
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
assert.True(t, isAzureTDX())
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return nil, fmt.Errorf("no vTPM")
|
||||
}
|
||||
assert.False(t, isAzureTDX())
|
||||
}
|
||||
|
||||
func TestExtractAzureTDXMeasurement_Success(t *testing.T) {
|
||||
oldValidator := DefaultValidator
|
||||
defer func() { DefaultValidator = oldValidator }()
|
||||
|
||||
DefaultValidator = &mockTokenValidator{
|
||||
validateFunc: func(token string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"tdx_mrtd": "mrtd",
|
||||
"tdx_mrseam": "mrseam",
|
||||
"tdx_rtmr0": "rtmr0",
|
||||
"tdx_rtmr1": "rtmr1",
|
||||
"tdx_rtmr2": "rtmr2",
|
||||
"tdx_rtmr3": "rtmr3",
|
||||
"tdx_seamsvn": float64(7),
|
||||
"unrelatedKey": "ignored",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
data, err := ExtractAzureTDXMeasurement("valid-token")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &AzureTDXMeasurementData{
|
||||
MRTD: "mrtd",
|
||||
MRSEAM: "mrseam",
|
||||
RTMRs: []string{"rtmr0", "rtmr1", "rtmr2", "rtmr3"},
|
||||
SEAMSVN: 7,
|
||||
}, data)
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyWithCoRIM_AzureTDX(t *testing.T) {
|
||||
decodedQuote, err := tdxabi.QuoteToProto(tdxtestdata.RawQuote)
|
||||
require.NoError(t, err)
|
||||
|
||||
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
|
||||
require.True(t, ok)
|
||||
mrtd := quoteV4.GetTdQuoteBody().GetMrTd()
|
||||
|
||||
c := comid.NewComid()
|
||||
c.SetTagIdentity("tdx-tag", 0)
|
||||
|
||||
m := comid.MustNewUintMeasurement(uint64(1))
|
||||
m.AddDigest(swid.Sha384, mrtd)
|
||||
m.SetRawValueBytes([]byte("raw"), nil)
|
||||
|
||||
rv := comid.ReferenceValue{
|
||||
Environment: comid.Environment{
|
||||
Class: comid.NewClassOID("1.2.3.4"),
|
||||
},
|
||||
Measurements: comid.Measurements{*m},
|
||||
}
|
||||
c.AddReferenceValue(rv)
|
||||
|
||||
manifest := corim.NewUnsignedCorim()
|
||||
manifest.SetID("test-tdx-corim")
|
||||
manifest.AddComid(*c)
|
||||
|
||||
err = NewVerifier(&bytes.Buffer{}).VerifyWithCoRIM(tdxtestdata.RawQuote, manifest)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -55,6 +55,8 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
|
||||
case attestation.SNPvTPM:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
|
||||
case attestation.Azure:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
|
||||
default:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
|
||||
}
|
||||
@@ -92,6 +94,8 @@ func (c *client) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
|
||||
case attestation.SNPvTPM:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
|
||||
case attestation.Azure:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
|
||||
default:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
@@ -174,6 +174,40 @@ func TestGetAttestationTDX(t *testing.T) {
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationAzure tests getting Azure attestation.
|
||||
func TestGetAttestationAzure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-azure-evidence.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.Azure)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationVTPM tests getting vTPM attestation.
|
||||
func TestGetAttestationVTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@@ -479,6 +513,40 @@ func TestGetRawEvidenceTDX(t *testing.T) {
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetRawEvidenceAzure tests getting raw evidence for Azure platform.
|
||||
func TestGetRawEvidenceAzure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "raw-evidence-azure.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.Azure)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, evidence)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetRawEvidenceVTPM tests getting raw evidence for VTPM platform.
|
||||
func TestGetRawEvidenceVTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
Reference in New Issue
Block a user