NOISSUE - Azure TDX Support (#596)

* initial Azure TDX support

* add tests

* update documentation

---------

Co-authored-by: Ubuntu <danko@cocos.nbzvzgavv4yeximq0jorvcggfd.dx.internal.cloudapp.net>
This commit is contained in:
Danko Miladinovic
2026-05-25 12:22:29 +02:00
committed by GitHub
parent 27db9b29eb
commit 02aa7d7d85
11 changed files with 1302 additions and 3 deletions
+25
View File
@@ -4,6 +4,10 @@
package azure
import (
"os"
"strings"
"time"
"github.com/edgelesssys/go-azguestattestation/maa"
)
@@ -28,6 +32,27 @@ func InitializeDefaultMAAVars(config *EnvConfig) {
maa.OSType = config.OSType
maa.OSDistro = config.OSDistro
MaaURL = config.MaaURL
InitializeDefaultAzureTDXVarsFromEnv()
}
func InitializeDefaultAzureTDXVars(imdsURL string, hclRefreshDelay time.Duration) {
if imdsURL = strings.TrimSpace(imdsURL); imdsURL != "" {
azureTDXIMDSQuoteURL = imdsURL
}
if hclRefreshDelay >= 0 {
azureTDXHCLRefreshDelay = hclRefreshDelay
}
}
func InitializeDefaultAzureTDXVarsFromEnv() {
imdsURL := os.Getenv("AZURE_TDX_IMDS_URL")
hclRefreshDelay := azureTDXHCLRefreshDelay
if value := strings.TrimSpace(os.Getenv("AZURE_HCL_REFRESH_WAIT")); value != "" {
if parsed, err := time.ParseDuration(value); err == nil {
hclRefreshDelay = parsed
}
}
InitializeDefaultAzureTDXVars(imdsURL, hclRefreshDelay)
}
func (c *EnvConfig) InitializeOSVars(build, osType, osDistro string) {
+32
View File
@@ -5,6 +5,7 @@ package azure
import (
"testing"
"time"
"github.com/edgelesssys/go-azguestattestation/maa"
"github.com/stretchr/testify/assert"
@@ -56,3 +57,34 @@ func TestInitializeOSVars(t *testing.T) {
assert.Equal(t, "TypeY", cfg.OSType)
assert.Equal(t, "DistroZ", cfg.OSDistro)
}
func TestInitializeDefaultAzureTDXVars(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
oldDelay := azureTDXHCLRefreshDelay
defer func() {
azureTDXIMDSQuoteURL = oldURL
azureTDXHCLRefreshDelay = oldDelay
}()
InitializeDefaultAzureTDXVars(" https://imds.example/tdquote ", 1500*time.Millisecond)
assert.Equal(t, "https://imds.example/tdquote", azureTDXIMDSQuoteURL)
assert.Equal(t, 1500*time.Millisecond, azureTDXHCLRefreshDelay)
}
func TestInitializeDefaultAzureTDXVarsFromEnv(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
oldDelay := azureTDXHCLRefreshDelay
defer func() {
azureTDXIMDSQuoteURL = oldURL
azureTDXHCLRefreshDelay = oldDelay
}()
t.Setenv("AZURE_TDX_IMDS_URL", "https://env-imds.example/tdquote")
t.Setenv("AZURE_HCL_REFRESH_WAIT", "2s")
InitializeDefaultAzureTDXVarsFromEnv()
assert.Equal(t, "https://env-imds.example/tdquote", azureTDXIMDSQuoteURL)
assert.Equal(t, 2*time.Second, azureTDXHCLRefreshDelay)
}
+24 -2
View File
@@ -52,6 +52,10 @@ func NewProvider() attestation.Provider {
}
func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) {
if isAzureTDX() {
return a.TeeAttestation(teeNonce)
}
var tokenNonce [vtpm.Nonce]byte
copy(tokenNonce[:], teeNonce)
@@ -77,6 +81,10 @@ func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
}
func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
if isAzureTDX() {
return fetchAzureTDXQuote(teeNonce)
}
var tokenNonce [vtpm.Nonce]byte
copy(tokenNonce[:], teeNonce)
@@ -89,7 +97,6 @@ func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
}
func (a provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
fmt.Printf("DEBUG: VTpmAttestation: vtpm.ExternalTPM is %T at %p\n", vtpm.ExternalTPM, &vtpm.ExternalTPM)
quote, err := vtpm.FetchQuote(vTpmNonce)
if err != nil {
return []byte{}, errors.Wrap(vtpm.ErrFetchQuote, err)
@@ -111,6 +118,14 @@ func (c *defaultMaaClient) Attest(ctx context.Context, nonce []byte, maaURL stri
var DefaultMaaClient MaaClient = &defaultMaaClient{}
func (a provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
if isAzureTDX() {
token, err := FetchAzureTDXAttestationToken(tokenNonce, MaaURL)
if err != nil {
return nil, errors.Wrap(ErrFetchAzureToken, err)
}
return token, nil
}
token, err := DefaultMaaClient.Attest(context.Background(), tokenNonce, MaaURL, http.DefaultClient)
if err != nil {
return nil, errors.Wrap(ErrFetchAzureToken, err)
@@ -152,12 +167,19 @@ func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte)
func (v verifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
attestation := &attest.Attestation{}
if err := proto.Unmarshal(report, attestation); err != nil {
return fmt.Errorf("failed to unmarshal attestation report: %w", err)
tdxErr := verifyTDXQuoteWithCoRIM(report, manifest)
if tdxErr == nil {
return nil
}
return fmt.Errorf("failed to unmarshal attestation report: %w; Azure TDX verification failed: %v", err, tdxErr)
}
// Extract measurement from SEV-SNP report if present
snpRep := attestation.GetSevSnpAttestation()
if snpRep == nil {
if tdxErr := verifyTDXQuoteWithCoRIM(report, manifest); tdxErr == nil {
return nil
}
return fmt.Errorf("no SEV-SNP attestation found in report")
}
+672
View File
@@ -0,0 +1,672 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package azure
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
tdxabi "github.com/google/go-tdx-guest/abi"
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"github.com/veraison/corim/comid"
"github.com/veraison/corim/corim"
)
const (
tdxAttestEndpoint = "attest/TdxVm"
tdxAPIVersion = "2025-06-01"
tdxRuntimeBinary = "Binary"
tdxRuntimeJSON = "JSON"
azureHCLReportNVIndex = 0x01400001
azureHCLReportDataNVIndex = 0x01400002
azureHCLSignature = "HCLA"
azureHCLVersion = 2
azureHCLRequestType = 2
azureHCLRuntimeDataVersion = 1
azureHCLHashSHA256 = 1
azureHCLReportTypeSNP = 2
azureHCLReportTypeTDX = 4
azureHCLHeaderSize = 0x20
azureHCLMaxHWReportSize = 0x4a0
azureHCLRuntimeDataOffset = azureHCLHeaderSize + azureHCLMaxHWReportSize
azureHCLRuntimeClaimsOffset = 0x14
azureTDReportSize = 0x400
azureTDReportDataOffset = 0x80
)
var (
azureTDXHCLReportReader = readAzureHCLReport
azureTDXReportDataWriter = writeAzureTDXReportData
azureTDXHCLRefreshDelay = 3 * time.Second
azureTDXIMDSQuoteURL = "http://169.254.169.254/acc/tdquote"
)
// TDXQuoteFetcher fetches a raw TDX quote for the provided REPORT_DATA.
type TDXQuoteFetcher interface {
FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error)
}
type TDXEvidenceFetcher interface {
FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error)
}
type defaultTDXQuoteFetcher struct{}
func (f defaultTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
evidence, err := f.FetchEvidence(reportData)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (f defaultTDXQuoteFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
hclReport, err := readFreshAzureTDXHCLReport(reportData)
if err != nil {
return nil, err
}
parsedReport, err := parseAzureHCLReport(hclReport)
if err != nil {
return nil, err
}
if parsedReport.reportType != azureHCLReportTypeTDX {
return nil, fmt.Errorf("Azure HCL report is not TDX")
}
if err := validateAzureTDXRuntimeClaimsHash(parsedReport); err != nil {
return nil, err
}
quote, err := DefaultAzureTDXIMDSClient.GetQuote(context.Background(), parsedReport.hwReport, http.DefaultClient)
if err != nil {
return nil, fmt.Errorf("failed to get Azure TDX quote from IMDS: %w", err)
}
return &azureTDXEvidence{
Quote: quote,
RuntimeData: append([]byte(nil), parsedReport.runtimeClaims...),
}, nil
}
// DefaultTDXQuoteFetcher is used by the Azure TDX provider and is replaceable in tests.
var DefaultTDXQuoteFetcher TDXQuoteFetcher = defaultTDXQuoteFetcher{}
// AzureTDXIMDSClient fetches an Azure TDX quote from the Azure Instance Metadata Service.
type AzureTDXIMDSClient interface {
GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error)
}
type defaultAzureTDXIMDSClient struct{}
// DefaultAzureTDXIMDSClient is used by the Azure TDX quote fetcher and is replaceable in tests.
var DefaultAzureTDXIMDSClient AzureTDXIMDSClient = &defaultAzureTDXIMDSClient{}
// AzureTDXClient submits Azure TDX VM attestation evidence to Microsoft Azure Attestation.
type AzureTDXClient interface {
AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error)
}
type defaultAzureTDXClient struct{}
// DefaultAzureTDXClient is used by Azure TDX token fetching and is replaceable in tests.
var DefaultAzureTDXClient AzureTDXClient = &defaultAzureTDXClient{}
type tdxAttestRequest struct {
Quote string `json:"quote"`
RuntimeData *tdxDataBlob `json:"runtimeData,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
type tdxDataBlob struct {
Data string `json:"data"`
DataType string `json:"dataType"`
}
type tdxAttestResponse struct {
Token string `json:"token"`
}
type tdxIMDSQuoteRequest struct {
Report string `json:"report"`
}
type tdxIMDSQuoteResponse struct {
Quote string `json:"quote"`
}
type azureTDXEvidence struct {
Quote []byte
RuntimeData []byte
}
type azureHCLReport struct {
reportType uint32
hashType uint32
hwReport []byte
runtimeClaims []byte
}
func isAzureTDX() bool {
if azureTDXHCLReportReader == nil {
return false
}
hclReport, err := azureTDXHCLReportReader()
if err != nil {
return false
}
parsedReport, err := parseAzureHCLReport(hclReport)
return err == nil && parsedReport.reportType == azureHCLReportTypeTDX
}
func fetchAzureTDXQuote(teeNonce []byte) ([]byte, error) {
if teeNonce == nil {
return nil, fmt.Errorf("tee nonce is required for Azure TDX attestation")
}
if len(teeNonce) != tdxabi.ReportDataSize {
return nil, fmt.Errorf("invalid tee nonce length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(teeNonce))
}
var reportData [tdxabi.ReportDataSize]byte
copy(reportData[:], teeNonce)
evidence, err := fetchAzureTDXEvidence(reportData, teeNonce)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (c *defaultAzureTDXClient) AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error) {
if maaURL == "" {
return "", fmt.Errorf("maaURL is empty")
}
if client == nil {
client = http.DefaultClient
}
maaURL, err := url.JoinPath(maaURL, tdxAttestEndpoint)
if err != nil {
return "", fmt.Errorf("parsing maaURL: %w", err)
}
maaURL += fmt.Sprintf("?api-version=%s", tdxAPIVersion)
attestRequest := tdxAttestRequest{
Quote: base64.RawURLEncoding.EncodeToString(quote),
}
if len(runtimeData) > 0 {
attestRequest.RuntimeData = &tdxDataBlob{
Data: base64.RawURLEncoding.EncodeToString(runtimeData),
DataType: tdxRuntimeDataType(runtimeData),
}
}
if len(nonce) > 0 {
attestRequest.Nonce = base64.RawURLEncoding.EncodeToString(nonce)
}
reqBytes, err := json.Marshal(attestRequest)
if err != nil {
return "", fmt.Errorf("marshaling TDX attestation request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, maaURL, bytes.NewReader(reqBytes))
if err != nil {
return "", fmt.Errorf("creating TDX attestation request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("doing TDX attestation request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
return "", fmt.Errorf("MAA returned %v: %s", resp.Status, msg)
}
return "", fmt.Errorf("MAA returned %v", resp.Status)
}
var attestResponse tdxAttestResponse
if err := json.NewDecoder(resp.Body).Decode(&attestResponse); err != nil {
return "", fmt.Errorf("decoding TDX attestation response: %w", err)
}
if attestResponse.Token == "" {
return "", fmt.Errorf("azure TDX attestation token not found in response")
}
return attestResponse.Token, nil
}
func (c *defaultAzureTDXIMDSClient) GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error) {
if len(tdReport) != azureTDReportSize {
return nil, fmt.Errorf("invalid TD report length: expected %d bytes, got %d bytes", azureTDReportSize, len(tdReport))
}
if client == nil {
client = http.DefaultClient
}
quoteRequest := tdxIMDSQuoteRequest{
Report: base64.RawURLEncoding.EncodeToString(tdReport),
}
reqBytes, err := json.Marshal(quoteRequest)
if err != nil {
return nil, fmt.Errorf("marshaling Azure TDX IMDS quote request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, azureTDXIMDSQuoteURL, bytes.NewReader(reqBytes))
if err != nil {
return nil, fmt.Errorf("creating Azure TDX IMDS quote request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("doing Azure TDX IMDS quote request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
return nil, fmt.Errorf("Azure TDX IMDS returned %v: %s", resp.Status, msg)
}
return nil, fmt.Errorf("Azure TDX IMDS returned %v", resp.Status)
}
var quoteResponse tdxIMDSQuoteResponse
if err := json.NewDecoder(resp.Body).Decode(&quoteResponse); err != nil {
return nil, fmt.Errorf("decoding Azure TDX IMDS quote response: %w", err)
}
if quoteResponse.Quote == "" {
return nil, fmt.Errorf("Azure TDX IMDS quote not found in response")
}
quote, err := decodeBase64URL(quoteResponse.Quote)
if err != nil {
return nil, fmt.Errorf("decoding Azure TDX IMDS quote: %w", err)
}
return quote, nil
}
// FetchAzureTDXAttestationToken fetches an Azure Attestation token for an Azure TDX VM.
func FetchAzureTDXAttestationToken(tokenNonce []byte, maaURL string) ([]byte, error) {
reportData := tdxReportDataFromRuntimeData(tokenNonce)
evidence, err := fetchAzureTDXEvidence(reportData, tokenNonce)
if err != nil {
return nil, fmt.Errorf("failed to fetch Azure TDX quote: %w", err)
}
token, err := DefaultAzureTDXClient.AttestTDXVM(context.Background(), evidence.Quote, evidence.RuntimeData, tokenNonce, maaURL, http.DefaultClient)
if err != nil {
return nil, fmt.Errorf("error fetching azure TDX token: %w", err)
}
return []byte(token), nil
}
func tdxReportDataFromRuntimeData(runtimeData []byte) [tdxabi.ReportDataSize]byte {
hash := sha256.Sum256(runtimeData)
var reportData [tdxabi.ReportDataSize]byte
copy(reportData[:sha256.Size], hash[:])
return reportData
}
func fetchAzureTDXEvidence(reportData [tdxabi.ReportDataSize]byte, fallbackRuntimeData []byte) (*azureTDXEvidence, error) {
if evidenceFetcher, ok := DefaultTDXQuoteFetcher.(TDXEvidenceFetcher); ok {
return evidenceFetcher.FetchEvidence(reportData)
}
quote, err := DefaultTDXQuoteFetcher.FetchQuote(reportData)
if err != nil {
return nil, err
}
return &azureTDXEvidence{
Quote: quote,
RuntimeData: append([]byte(nil), fallbackRuntimeData...),
}, nil
}
func readFreshAzureTDXHCLReport(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
if azureTDXReportDataWriter != nil {
if err := azureTDXReportDataWriter(reportData[:]); err != nil {
return nil, fmt.Errorf("writing Azure TDX report data: %w", err)
}
if azureTDXHCLRefreshDelay > 0 {
time.Sleep(azureTDXHCLRefreshDelay)
}
}
if azureTDXHCLReportReader == nil {
return nil, fmt.Errorf("Azure TDX HCL report reader is not configured")
}
hclReport, err := azureTDXHCLReportReader()
if err != nil {
return nil, fmt.Errorf("reading Azure TDX HCL report: %w", err)
}
return hclReport, nil
}
func readAzureHCLReport() ([]byte, error) {
tpm, err := tpm2.OpenTPM()
if err != nil {
return nil, err
}
defer tpm.Close()
return tpm2.NVReadEx(tpm, azureHCLReportNVIndex, tpm2.HandleOwner, "", 0)
}
func writeAzureTDXReportData(data []byte) error {
if len(data) != tdxabi.ReportDataSize {
return fmt.Errorf("invalid Azure TDX report data length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(data))
}
tpm, err := tpm2.OpenTPM()
if err != nil {
return err
}
defer tpm.Close()
return writeAzureTDXReportDataToTPM(tpm, data)
}
func writeAzureTDXReportDataToTPM(tpm io.ReadWriter, data []byte) error {
if len(data) > int(^uint16(0)) {
return fmt.Errorf("Azure TDX report data is too large")
}
if err := ensureAzureTDXReportDataIndex(tpm, uint16(len(data))); err != nil {
return err
}
if err := tpm2.NVWrite(tpm, tpm2.HandleOwner, azureHCLReportDataNVIndex, "", tpmutil.U16Bytes(data), 0); err != nil {
return fmt.Errorf("writing Azure TDX report-data NV index: %w", err)
}
return nil
}
func ensureAzureTDXReportDataIndex(tpm io.ReadWriter, size uint16) error {
pub, err := tpm2.NVReadPublic(tpm, azureHCLReportDataNVIndex)
if err == nil {
if pub.DataSize == size {
return nil
}
if err := tpm2.NVUndefineSpace(tpm, "", tpm2.HandleOwner, azureHCLReportDataNVIndex); err != nil {
return fmt.Errorf("undefining mismatched Azure TDX report-data NV index: %w", err)
}
}
nvPub := tpm2.NVPublic{
NVIndex: azureHCLReportDataNVIndex,
NameAlg: tpm2.AlgSHA256,
Attributes: tpm2.AttrOwnerWrite | tpm2.AttrOwnerRead,
DataSize: size,
}
authArea := tpm2.AuthCommand{
Session: tpm2.HandlePasswordSession,
Attributes: tpm2.AttrContinueSession,
Auth: []byte(""),
}
if err := tpm2.NVDefineSpaceEx(tpm, tpm2.HandleOwner, "", nvPub, authArea); err != nil {
return fmt.Errorf("defining Azure TDX report-data NV index: %w", err)
}
return nil
}
func parseAzureHCLReport(report []byte) (*azureHCLReport, error) {
minSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset
if len(report) < minSize {
return nil, fmt.Errorf("invalid Azure HCL report size: expected at least %d bytes, got %d bytes", minSize, len(report))
}
if string(report[:len(azureHCLSignature)]) != azureHCLSignature {
return nil, fmt.Errorf("invalid Azure HCL report signature")
}
if version := binary.LittleEndian.Uint32(report[4:8]); version != azureHCLVersion {
return nil, fmt.Errorf("invalid Azure HCL report version: expected %d, got %d", azureHCLVersion, version)
}
reportSize := binary.LittleEndian.Uint32(report[8:12])
if reportSize > uint32(len(report)) {
return nil, fmt.Errorf("invalid Azure HCL report size: header reports %d bytes, got %d bytes", reportSize, len(report))
}
if requestType := binary.LittleEndian.Uint32(report[12:16]); requestType != azureHCLRequestType {
return nil, fmt.Errorf("invalid Azure HCL report request type: expected %d, got %d", azureHCLRequestType, requestType)
}
runtimeData := report[azureHCLRuntimeDataOffset:]
dataSize := binary.LittleEndian.Uint32(runtimeData[0:4])
if dataSize < azureHCLRuntimeClaimsOffset {
return nil, fmt.Errorf("invalid Azure HCL runtime data size: %d", dataSize)
}
if azureHCLRuntimeDataOffset+int(dataSize) > len(report) {
return nil, fmt.Errorf("invalid Azure HCL runtime data size: header reports %d bytes", dataSize)
}
if version := binary.LittleEndian.Uint32(runtimeData[4:8]); version != azureHCLRuntimeDataVersion {
return nil, fmt.Errorf("invalid Azure HCL runtime data version: expected %d, got %d", azureHCLRuntimeDataVersion, version)
}
reportType := binary.LittleEndian.Uint32(runtimeData[8:12])
if reportType != azureHCLReportTypeSNP && reportType != azureHCLReportTypeTDX {
return nil, fmt.Errorf("invalid Azure HCL report type: %d", reportType)
}
hashType := binary.LittleEndian.Uint32(runtimeData[12:16])
claimsSize := binary.LittleEndian.Uint32(runtimeData[16:20])
claimsEnd := azureHCLRuntimeClaimsOffset + int(claimsSize)
if claimsEnd > int(dataSize) {
return nil, fmt.Errorf("invalid Azure HCL runtime claims size: %d", claimsSize)
}
hwReportSize := azureHCLMaxHWReportSize
if reportType == azureHCLReportTypeTDX {
hwReportSize = azureTDReportSize
}
if azureHCLHeaderSize+hwReportSize > len(report) {
return nil, fmt.Errorf("invalid Azure HCL hardware report size: %d", hwReportSize)
}
return &azureHCLReport{
reportType: reportType,
hashType: hashType,
hwReport: append([]byte(nil), report[azureHCLHeaderSize:azureHCLHeaderSize+hwReportSize]...),
runtimeClaims: append([]byte(nil), runtimeData[azureHCLRuntimeClaimsOffset:claimsEnd]...),
}, nil
}
func validateAzureTDXRuntimeClaimsHash(report *azureHCLReport) error {
if report.hashType != azureHCLHashSHA256 {
return fmt.Errorf("unsupported Azure HCL runtime data hash type: %d", report.hashType)
}
if len(report.hwReport) < azureTDReportDataOffset+sha256.Size {
return fmt.Errorf("invalid Azure TDX TD report size: %d", len(report.hwReport))
}
hash := sha256.Sum256(report.runtimeClaims)
if !bytes.Equal(hash[:], report.hwReport[azureTDReportDataOffset:azureTDReportDataOffset+sha256.Size]) {
return fmt.Errorf("Azure TDX runtime claims hash does not match TD report data")
}
return nil
}
func tdxRuntimeDataType(runtimeData []byte) string {
if json.Valid(runtimeData) {
return tdxRuntimeJSON
}
return tdxRuntimeBinary
}
func decodeBase64URL(value string) ([]byte, error) {
decoded, err := base64.RawURLEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
return base64.URLEncoding.DecodeString(value)
}
func verifyTDXQuoteWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
decodedQuote, err := tdxabi.QuoteToProto(report)
if err != nil {
return fmt.Errorf("failed to parse TDX quote: %w", err)
}
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
if !ok {
return fmt.Errorf("unsupported TDX quote format")
}
tdReport := quoteV4.GetTdQuoteBody()
if tdReport == nil {
return fmt.Errorf("missing TDX quote body")
}
mrtd := tdReport.GetMrTd()
if len(mrtd) == 0 {
return fmt.Errorf("no MRTD in TDX quote")
}
if err := matchMeasurementInCoRIM(manifest, mrtd); err != nil {
return fmt.Errorf("%w for Azure TDX", err)
}
return nil
}
func matchMeasurementInCoRIM(manifest *corim.UnsignedCorim, measurement []byte) error {
if manifest == nil || len(manifest.Tags) == 0 {
return fmt.Errorf("no tags in CoRIM")
}
for _, tag := range manifest.Tags {
if !bytes.HasPrefix(tag, corim.ComidTag) {
continue
}
tagValue := tag[len(corim.ComidTag):]
var c comid.Comid
if err := c.FromCBOR(tagValue); err != nil {
return fmt.Errorf("failed to parse CoMID: %w", err)
}
if c.Triples.ReferenceValues == nil {
continue
}
for _, rv := range *c.Triples.ReferenceValues {
for _, m := range rv.Measurements {
if m.Val.Digests == nil {
continue
}
for _, digest := range *m.Val.Digests {
if bytes.Equal(digest.HashValue, measurement) {
return nil
}
}
}
}
}
return fmt.Errorf("no matching reference value found in CoRIM")
}
// AzureTDXMeasurementData contains the fields extracted from an Azure TDX attestation token
// needed to construct a CoRIM policy for the TDX platform.
type AzureTDXMeasurementData struct {
MRTD string
MRSEAM string
RTMRs []string
SEAMSVN uint64
}
// ExtractAzureTDXMeasurement extracts core TDX measurements from an Azure Attestation token.
func ExtractAzureTDXMeasurement(token string) (*AzureTDXMeasurementData, error) {
claims, err := DefaultValidator.Validate(token)
if err != nil {
return nil, fmt.Errorf("failed to validate token: %w", err)
}
mrtd, ok := azureClaimString(claims, "tdx_mrtd")
if !ok {
return nil, fmt.Errorf("failed to get MRTD from claims")
}
mrSeam, _ := azureClaimString(claims, "tdx_mrseam")
rtmrs := make([]string, 0, 4)
for _, name := range []string{"tdx_rtmr0", "tdx_rtmr1", "tdx_rtmr2", "tdx_rtmr3"} {
if value, ok := azureClaimString(claims, name); ok {
rtmrs = append(rtmrs, value)
}
}
seamSVN, _ := azureClaimUint64(claims, "tdx_seamsvn")
return &AzureTDXMeasurementData{
MRTD: mrtd,
MRSEAM: mrSeam,
RTMRs: rtmrs,
SEAMSVN: seamSVN,
}, nil
}
func azureClaimString(claims map[string]any, name string) (string, bool) {
if value, ok := claims[name].(string); ok {
return value, true
}
tee, ok := claims["x-ms-isolation-tee"].(map[string]any)
if !ok {
return "", false
}
value, ok := tee[name].(string)
return value, ok
}
func azureClaimUint64(claims map[string]any, name string) (uint64, bool) {
value, ok := claims[name]
if !ok {
tee, teeOK := claims["x-ms-isolation-tee"].(map[string]any)
if !teeOK {
return 0, false
}
value, ok = tee[name]
if !ok {
return 0, false
}
}
switch typed := value.(type) {
case float64:
return uint64(typed), true
case int:
return uint64(typed), true
case uint64:
return typed, true
default:
return 0, false
}
}
+455
View File
@@ -0,0 +1,455 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package azure
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
tdxabi "github.com/google/go-tdx-guest/abi"
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
tdxtestdata "github.com/google/go-tdx-guest/testing/testdata"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/veraison/corim/comid"
"github.com/veraison/corim/corim"
"github.com/veraison/swid"
)
type mockTDXQuoteFetcher struct {
quote []byte
err error
gotReportData [tdxabi.ReportDataSize]byte
fetchQuoteCall bool
}
func (m *mockTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
m.fetchQuoteCall = true
m.gotReportData = reportData
if m.err != nil {
return nil, m.err
}
return m.quote, nil
}
type mockTDXEvidenceFetcher struct {
evidence *azureTDXEvidence
err error
gotReportData [tdxabi.ReportDataSize]byte
fetchEvidenceHit bool
}
func (m *mockTDXEvidenceFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
evidence, err := m.FetchEvidence(reportData)
if err != nil {
return nil, err
}
return evidence.Quote, nil
}
func (m *mockTDXEvidenceFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
m.fetchEvidenceHit = true
m.gotReportData = reportData
if m.err != nil {
return nil, m.err
}
return m.evidence, nil
}
type mockAzureTDXClient struct {
token string
err error
gotQuote []byte
gotRuntime []byte
gotNonce []byte
gotMaaURL string
attestCalls int
}
func (m *mockAzureTDXClient) AttestTDXVM(_ context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, _ *http.Client) (string, error) {
m.attestCalls++
m.gotQuote = append([]byte(nil), quote...)
m.gotRuntime = append([]byte(nil), runtimeData...)
m.gotNonce = append([]byte(nil), nonce...)
m.gotMaaURL = maaURL
if m.err != nil {
return "", m.err
}
return m.token, nil
}
type mockAzureTDXIMDSClient struct {
quote []byte
err error
gotTDReport []byte
getQuoteCall bool
}
func (m *mockAzureTDXIMDSClient) GetQuote(_ context.Context, tdReport []byte, _ *http.Client) ([]byte, error) {
m.getQuoteCall = true
m.gotTDReport = append([]byte(nil), tdReport...)
if m.err != nil {
return nil, m.err
}
return m.quote, nil
}
func testAzureHCLReport(reportType uint32, runtimeClaims []byte) []byte {
reportSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset + len(runtimeClaims)
hclReport := make([]byte, reportSize)
copy(hclReport[:len(azureHCLSignature)], azureHCLSignature)
binary.LittleEndian.PutUint32(hclReport[4:8], azureHCLVersion)
binary.LittleEndian.PutUint32(hclReport[8:12], uint32(reportSize))
binary.LittleEndian.PutUint32(hclReport[12:16], azureHCLRequestType)
runtimeData := hclReport[azureHCLRuntimeDataOffset:]
binary.LittleEndian.PutUint32(runtimeData[0:4], uint32(azureHCLRuntimeClaimsOffset+len(runtimeClaims)))
binary.LittleEndian.PutUint32(runtimeData[4:8], azureHCLRuntimeDataVersion)
binary.LittleEndian.PutUint32(runtimeData[8:12], reportType)
binary.LittleEndian.PutUint32(runtimeData[12:16], azureHCLHashSHA256)
binary.LittleEndian.PutUint32(runtimeData[16:20], uint32(len(runtimeClaims)))
copy(runtimeData[azureHCLRuntimeClaimsOffset:], runtimeClaims)
if reportType == azureHCLReportTypeTDX {
hash := sha256.Sum256(runtimeClaims)
copy(hclReport[azureHCLHeaderSize+azureTDReportDataOffset:], hash[:])
}
return hclReport
}
func TestProvider_TeeAttestation_AzureTDX(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldFetcher := DefaultTDXQuoteFetcher
defer func() {
azureTDXHCLReportReader = oldReader
DefaultTDXQuoteFetcher = oldFetcher
}()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
fetcher := &mockTDXQuoteFetcher{quote: []byte("tdx-quote")}
DefaultTDXQuoteFetcher = fetcher
reportData := bytes.Repeat([]byte{0xAB}, tdxabi.ReportDataSize)
got, err := NewProvider().TeeAttestation(reportData)
require.NoError(t, err)
assert.Equal(t, []byte("tdx-quote"), got)
assert.True(t, fetcher.fetchQuoteCall)
assert.Equal(t, reportData, fetcher.gotReportData[:])
}
func TestProvider_TeeAttestation_AzureTDX_InvalidNonce(t *testing.T) {
oldReader := azureTDXHCLReportReader
defer func() { azureTDXHCLReportReader = oldReader }()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
_, err := NewProvider().TeeAttestation([]byte("short"))
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid tee nonce length")
}
func TestProvider_AzureAttestationToken_AzureTDX(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldFetcher := DefaultTDXQuoteFetcher
oldClient := DefaultAzureTDXClient
oldMaaURL := MaaURL
defer func() {
azureTDXHCLReportReader = oldReader
DefaultTDXQuoteFetcher = oldFetcher
DefaultAzureTDXClient = oldClient
MaaURL = oldMaaURL
}()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
MaaURL = "https://tdx.example.attest.azure.net"
fetcher := &mockTDXQuoteFetcher{quote: []byte("quote")}
client := &mockAzureTDXClient{token: "tdx-token"}
DefaultTDXQuoteFetcher = fetcher
DefaultAzureTDXClient = client
nonce := []byte("token-nonce")
got, err := NewProvider().AzureAttestationToken(nonce)
require.NoError(t, err)
assert.Equal(t, []byte("tdx-token"), got)
expectedReportData := tdxReportDataFromRuntimeData(nonce)
assert.Equal(t, expectedReportData, fetcher.gotReportData)
assert.Equal(t, []byte("quote"), client.gotQuote)
assert.Equal(t, nonce, client.gotRuntime)
assert.Equal(t, nonce, client.gotNonce)
assert.Equal(t, MaaURL, client.gotMaaURL)
}
func TestFetchAzureTDXAttestationToken_UsesHCLRuntimeClaims(t *testing.T) {
oldFetcher := DefaultTDXQuoteFetcher
oldClient := DefaultAzureTDXClient
defer func() {
DefaultTDXQuoteFetcher = oldFetcher
DefaultAzureTDXClient = oldClient
}()
runtimeClaims := []byte(`{"keys":[],"user-data":"nonce"}`)
fetcher := &mockTDXEvidenceFetcher{
evidence: &azureTDXEvidence{
Quote: []byte("quote"),
RuntimeData: runtimeClaims,
},
}
client := &mockAzureTDXClient{token: "tdx-token"}
DefaultTDXQuoteFetcher = fetcher
DefaultAzureTDXClient = client
nonce := []byte("token-nonce")
got, err := FetchAzureTDXAttestationToken(nonce, "https://tdx.example.attest.azure.net")
require.NoError(t, err)
assert.Equal(t, []byte("tdx-token"), got)
assert.True(t, fetcher.fetchEvidenceHit)
assert.Equal(t, tdxReportDataFromRuntimeData(nonce), fetcher.gotReportData)
assert.Equal(t, runtimeClaims, client.gotRuntime)
assert.Equal(t, nonce, client.gotNonce)
}
func TestFetchAzureTDXAttestationToken_FetchQuoteError(t *testing.T) {
oldFetcher := DefaultTDXQuoteFetcher
defer func() { DefaultTDXQuoteFetcher = oldFetcher }()
DefaultTDXQuoteFetcher = &mockTDXQuoteFetcher{err: fmt.Errorf("quote unavailable")}
_, err := FetchAzureTDXAttestationToken([]byte("nonce"), "https://tdx.example.attest.azure.net")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch Azure TDX quote")
}
func TestDefaultTDXQuoteFetcher_FetchEvidence_AzureHCLIMDS(t *testing.T) {
oldReader := azureTDXHCLReportReader
oldWriter := azureTDXReportDataWriter
oldDelay := azureTDXHCLRefreshDelay
oldIMDSClient := DefaultAzureTDXIMDSClient
defer func() {
azureTDXHCLReportReader = oldReader
azureTDXReportDataWriter = oldWriter
azureTDXHCLRefreshDelay = oldDelay
DefaultAzureTDXIMDSClient = oldIMDSClient
}()
runtimeClaims := []byte(`{"keys":[],"user-data":"fresh"}`)
hclReport := testAzureHCLReport(azureHCLReportTypeTDX, runtimeClaims)
var gotReportData []byte
azureTDXReportDataWriter = func(data []byte) error {
gotReportData = append([]byte(nil), data...)
return nil
}
azureTDXHCLReportReader = func() ([]byte, error) {
return hclReport, nil
}
azureTDXHCLRefreshDelay = 0
imdsClient := &mockAzureTDXIMDSClient{quote: []byte("tdx-quote")}
DefaultAzureTDXIMDSClient = imdsClient
reportData := [tdxabi.ReportDataSize]byte{0xAB}
evidence, err := defaultTDXQuoteFetcher{}.FetchEvidence(reportData)
require.NoError(t, err)
assert.Equal(t, reportData[:], gotReportData)
assert.Equal(t, []byte("tdx-quote"), evidence.Quote)
assert.Equal(t, runtimeClaims, evidence.RuntimeData)
assert.True(t, imdsClient.getQuoteCall)
assert.Len(t, imdsClient.gotTDReport, azureTDReportSize)
}
func TestDefaultAzureTDXClient_AttestTDXVM(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/attest/TdxVm", r.URL.Path)
assert.Equal(t, tdxAPIVersion, r.URL.Query().Get("api-version"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
var req tdxAttestRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("quote")), req.Quote)
require.NotNil(t, req.RuntimeData)
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("runtime")), req.RuntimeData.Data)
assert.Equal(t, tdxRuntimeBinary, req.RuntimeData.DataType)
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("nonce")), req.Nonce)
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
}))
defer server.Close()
token, err := (&defaultAzureTDXClient{}).AttestTDXVM(
context.Background(),
[]byte("quote"),
[]byte("runtime"),
[]byte("nonce"),
server.URL,
server.Client(),
)
require.NoError(t, err)
assert.Equal(t, "tdx-token", token)
}
func TestDefaultAzureTDXClient_AttestTDXVM_JSONRuntimeData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req tdxAttestRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
require.NotNil(t, req.RuntimeData)
assert.Equal(t, tdxRuntimeJSON, req.RuntimeData.DataType)
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
}))
defer server.Close()
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(
context.Background(),
[]byte("quote"),
[]byte(`{"keys":[]}`),
[]byte("nonce"),
server.URL,
server.Client(),
)
require.NoError(t, err)
}
func TestDefaultAzureTDXClient_AttestTDXVM_ErrorStatus(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad quote", http.StatusBadRequest)
}))
defer server.Close()
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(context.Background(), []byte("quote"), nil, nil, server.URL, server.Client())
require.Error(t, err)
assert.Contains(t, err.Error(), "MAA returned 400 Bad Request")
}
func TestDefaultAzureTDXIMDSClient_GetQuote(t *testing.T) {
oldURL := azureTDXIMDSQuoteURL
defer func() { azureTDXIMDSQuoteURL = oldURL }()
tdReport := bytes.Repeat([]byte{0xA5}, azureTDReportSize)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
var req tdxIMDSQuoteRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, base64.RawURLEncoding.EncodeToString(tdReport), req.Report)
resp := tdxIMDSQuoteResponse{
Quote: base64.RawURLEncoding.EncodeToString([]byte("quote")),
}
require.NoError(t, json.NewEncoder(w).Encode(resp))
}))
defer server.Close()
azureTDXIMDSQuoteURL = server.URL
quote, err := (&defaultAzureTDXIMDSClient{}).GetQuote(context.Background(), tdReport, server.Client())
require.NoError(t, err)
assert.Equal(t, []byte("quote"), quote)
}
func TestIsAzureTDX_UsesHCLReportType(t *testing.T) {
oldReader := azureTDXHCLReportReader
defer func() { azureTDXHCLReportReader = oldReader }()
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeSNP, []byte("runtime")), nil
}
assert.False(t, isAzureTDX())
azureTDXHCLReportReader = func() ([]byte, error) {
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
}
assert.True(t, isAzureTDX())
azureTDXHCLReportReader = func() ([]byte, error) {
return nil, fmt.Errorf("no vTPM")
}
assert.False(t, isAzureTDX())
}
func TestExtractAzureTDXMeasurement_Success(t *testing.T) {
oldValidator := DefaultValidator
defer func() { DefaultValidator = oldValidator }()
DefaultValidator = &mockTokenValidator{
validateFunc: func(token string) (map[string]any, error) {
return map[string]any{
"tdx_mrtd": "mrtd",
"tdx_mrseam": "mrseam",
"tdx_rtmr0": "rtmr0",
"tdx_rtmr1": "rtmr1",
"tdx_rtmr2": "rtmr2",
"tdx_rtmr3": "rtmr3",
"tdx_seamsvn": float64(7),
"unrelatedKey": "ignored",
}, nil
},
}
data, err := ExtractAzureTDXMeasurement("valid-token")
require.NoError(t, err)
assert.Equal(t, &AzureTDXMeasurementData{
MRTD: "mrtd",
MRSEAM: "mrseam",
RTMRs: []string{"rtmr0", "rtmr1", "rtmr2", "rtmr3"},
SEAMSVN: 7,
}, data)
}
func TestVerifier_VerifyWithCoRIM_AzureTDX(t *testing.T) {
decodedQuote, err := tdxabi.QuoteToProto(tdxtestdata.RawQuote)
require.NoError(t, err)
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
require.True(t, ok)
mrtd := quoteV4.GetTdQuoteBody().GetMrTd()
c := comid.NewComid()
c.SetTagIdentity("tdx-tag", 0)
m := comid.MustNewUintMeasurement(uint64(1))
m.AddDigest(swid.Sha384, mrtd)
m.SetRawValueBytes([]byte("raw"), nil)
rv := comid.ReferenceValue{
Environment: comid.Environment{
Class: comid.NewClassOID("1.2.3.4"),
},
Measurements: comid.Measurements{*m},
}
c.AddReferenceValue(rv)
manifest := corim.NewUnsignedCorim()
manifest.SetID("test-tdx-corim")
manifest.AddComid(*c)
err = NewVerifier(&bytes.Buffer{}).VerifyWithCoRIM(tdxtestdata.RawQuote, manifest)
assert.NoError(t, err)
}
+4
View File
@@ -55,6 +55,8 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
case attestation.SNPvTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
case attestation.Azure:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
default:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
@@ -92,6 +94,8 @@ func (c *client) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
case attestation.SNPvTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
case attestation.Azure:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
default:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
@@ -174,6 +174,40 @@ func TestGetAttestationTDX(t *testing.T) {
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetAttestationAzure tests getting Azure attestation.
func TestGetAttestationAzure(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-azure-evidence.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.Azure)
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
}
// TestGetAttestationVTPM tests getting vTPM attestation.
func TestGetAttestationVTPM(t *testing.T) {
tmpDir := t.TempDir()
@@ -479,6 +513,40 @@ func TestGetRawEvidenceTDX(t *testing.T) {
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetRawEvidenceAzure tests getting raw evidence for Azure platform.
func TestGetRawEvidenceAzure(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-azure.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.Azure)
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
}
// TestGetRawEvidenceVTPM tests getting raw evidence for VTPM platform.
func TestGetRawEvidenceVTPM(t *testing.T) {
tmpDir := t.TempDir()