mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
NOISSUE - Azure TDX Support (#596)
* initial Azure TDX support * add tests * update documentation --------- Co-authored-by: Ubuntu <danko@cocos.nbzvzgavv4yeximq0jorvcggfd.dx.internal.cloudapp.net>
This commit is contained in:
committed by
GitHub
parent
27db9b29eb
commit
02aa7d7d85
@@ -21,12 +21,20 @@ The service is configured using the environment variables from the following tab
|
||||
| AGENT_CVM_ID | Unique identifier for the CVM (Confidential Virtual Machine) | "" |
|
||||
| AGENT_CERTS_TOKEN | Authentication token for certificate service access | "" |
|
||||
| AGENT_MAA_URL | Microsoft Azure Attestation service URL for Azure attestation | https://sharedeus2.eus2.attest.azure.net |
|
||||
| AZURE_TDX_IMDS_URL | Azure TDX quote endpoint used by direct Azure TDX attestation | http://169.254.169.254/acc/tdquote |
|
||||
| AZURE_HCL_REFRESH_WAIT | Wait after writing TDX report data to Azure HCL vTPM storage before reading the refreshed HCL report | 3s |
|
||||
| AGENT_OS_BUILD | Operating system build information for attestation | UVC |
|
||||
| AGENT_OS_DISTRO | Operating system distribution information for attestation | UVC |
|
||||
| AGENT_OS_TYPE | Operating system type information for attestation | UVC |
|
||||
| ATTESTATION_SERVICE_SOCKET | Unix socket path for attestation service communication | /run/cocos/attestation.sock |
|
||||
| AGENT_ENABLE_ATLS | Enable Attestation TLS for secure communication | true |
|
||||
|
||||
### Azure TDX Attestation
|
||||
|
||||
When the agent runs on an Azure TDX CVM, Azure attestation uses the direct Azure TDX flow. The agent writes TDX report data to Azure HCL vTPM storage, reads the refreshed HCL report, requests a TD quote from Azure IMDS, and submits the quote plus HCL runtime data to Microsoft Azure Attestation. This path does not depend on Confidential Containers attestation-agent `GetEvidence` or KBS token retrieval.
|
||||
|
||||
`AGENT_MAA_URL` selects the Microsoft Azure Attestation endpoint. `AZURE_TDX_IMDS_URL` can override the Azure IMDS TDX quote endpoint, and `AZURE_HCL_REFRESH_WAIT` controls the wait used to avoid reading a stale HCL report after report-data is written.
|
||||
|
||||
### Remote Resource Download (Optional)
|
||||
|
||||
The agent supports downloading encrypted algorithms and datasets from remote registries (S3, HTTP/HTTPS) and retrieving decryption keys from a Key Broker Service (KBS) via attestation.
|
||||
|
||||
@@ -60,7 +60,7 @@ func (req azureAttestationTokenReq) validate() error {
|
||||
|
||||
func validateAttestationType(attType attestation.PlatformType) error {
|
||||
switch attType {
|
||||
case attestation.SNP, attestation.VTPM, attestation.SNPvTPM, attestation.TDX:
|
||||
case attestation.SNP, attestation.VTPM, attestation.SNPvTPM, attestation.Azure, attestation.TDX:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid attestation type")
|
||||
|
||||
@@ -33,6 +33,12 @@ func (s *service) FetchRawEvidence(ctx context.Context, req *attestationpb.Attes
|
||||
var nonce [32]byte
|
||||
copy(nonce[:], req.Nonce)
|
||||
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
|
||||
case attestationpb.PlatformType_PLATFORM_TYPE_AZURE:
|
||||
var reportData [64]byte
|
||||
copy(reportData[:], req.ReportData)
|
||||
var nonce [32]byte
|
||||
copy(nonce[:], req.Nonce)
|
||||
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
|
||||
case attestationpb.PlatformType_PLATFORM_TYPE_UNSPECIFIED:
|
||||
// Generate sample attestation for testing in non-TEE environments
|
||||
// This uses the underlying provider (EmptyProvider or CC Attestation Agent)
|
||||
|
||||
@@ -318,6 +318,13 @@ func (s *service) FetchAttestation(ctx context.Context, req *attestationpb.Attes
|
||||
copy(nonce[:], req.Nonce)
|
||||
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
|
||||
platformType = attestation.SNPvTPM
|
||||
case attestationpb.PlatformType_PLATFORM_TYPE_AZURE:
|
||||
var reportData [64]byte
|
||||
copy(reportData[:], req.ReportData)
|
||||
var nonce [32]byte
|
||||
copy(nonce[:], req.Nonce)
|
||||
binaryReport, err = s.provider.Attestation(reportData[:], nonce[:])
|
||||
platformType = attestation.Azure
|
||||
case attestationpb.PlatformType_PLATFORM_TYPE_UNSPECIFIED:
|
||||
// Generate sample attestation for testing in non-TEE environments
|
||||
s.logger.Warn("generating sample attestation for PLATFORM_TYPE_UNSPECIFIED - this should only be used for testing")
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/go-azguestattestation/maa"
|
||||
)
|
||||
|
||||
@@ -28,6 +32,27 @@ func InitializeDefaultMAAVars(config *EnvConfig) {
|
||||
maa.OSType = config.OSType
|
||||
maa.OSDistro = config.OSDistro
|
||||
MaaURL = config.MaaURL
|
||||
InitializeDefaultAzureTDXVarsFromEnv()
|
||||
}
|
||||
|
||||
func InitializeDefaultAzureTDXVars(imdsURL string, hclRefreshDelay time.Duration) {
|
||||
if imdsURL = strings.TrimSpace(imdsURL); imdsURL != "" {
|
||||
azureTDXIMDSQuoteURL = imdsURL
|
||||
}
|
||||
if hclRefreshDelay >= 0 {
|
||||
azureTDXHCLRefreshDelay = hclRefreshDelay
|
||||
}
|
||||
}
|
||||
|
||||
func InitializeDefaultAzureTDXVarsFromEnv() {
|
||||
imdsURL := os.Getenv("AZURE_TDX_IMDS_URL")
|
||||
hclRefreshDelay := azureTDXHCLRefreshDelay
|
||||
if value := strings.TrimSpace(os.Getenv("AZURE_HCL_REFRESH_WAIT")); value != "" {
|
||||
if parsed, err := time.ParseDuration(value); err == nil {
|
||||
hclRefreshDelay = parsed
|
||||
}
|
||||
}
|
||||
InitializeDefaultAzureTDXVars(imdsURL, hclRefreshDelay)
|
||||
}
|
||||
|
||||
func (c *EnvConfig) InitializeOSVars(build, osType, osDistro string) {
|
||||
|
||||
@@ -5,6 +5,7 @@ package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/go-azguestattestation/maa"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -56,3 +57,34 @@ func TestInitializeOSVars(t *testing.T) {
|
||||
assert.Equal(t, "TypeY", cfg.OSType)
|
||||
assert.Equal(t, "DistroZ", cfg.OSDistro)
|
||||
}
|
||||
|
||||
func TestInitializeDefaultAzureTDXVars(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
defer func() {
|
||||
azureTDXIMDSQuoteURL = oldURL
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
}()
|
||||
|
||||
InitializeDefaultAzureTDXVars(" https://imds.example/tdquote ", 1500*time.Millisecond)
|
||||
|
||||
assert.Equal(t, "https://imds.example/tdquote", azureTDXIMDSQuoteURL)
|
||||
assert.Equal(t, 1500*time.Millisecond, azureTDXHCLRefreshDelay)
|
||||
}
|
||||
|
||||
func TestInitializeDefaultAzureTDXVarsFromEnv(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
defer func() {
|
||||
azureTDXIMDSQuoteURL = oldURL
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
}()
|
||||
|
||||
t.Setenv("AZURE_TDX_IMDS_URL", "https://env-imds.example/tdquote")
|
||||
t.Setenv("AZURE_HCL_REFRESH_WAIT", "2s")
|
||||
|
||||
InitializeDefaultAzureTDXVarsFromEnv()
|
||||
|
||||
assert.Equal(t, "https://env-imds.example/tdquote", azureTDXIMDSQuoteURL)
|
||||
assert.Equal(t, 2*time.Second, azureTDXHCLRefreshDelay)
|
||||
}
|
||||
|
||||
@@ -52,6 +52,10 @@ func NewProvider() attestation.Provider {
|
||||
}
|
||||
|
||||
func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
return a.TeeAttestation(teeNonce)
|
||||
}
|
||||
|
||||
var tokenNonce [vtpm.Nonce]byte
|
||||
copy(tokenNonce[:], teeNonce)
|
||||
|
||||
@@ -77,6 +81,10 @@ func (a provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
return fetchAzureTDXQuote(teeNonce)
|
||||
}
|
||||
|
||||
var tokenNonce [vtpm.Nonce]byte
|
||||
copy(tokenNonce[:], teeNonce)
|
||||
|
||||
@@ -89,7 +97,6 @@ func (a provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (a provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
|
||||
fmt.Printf("DEBUG: VTpmAttestation: vtpm.ExternalTPM is %T at %p\n", vtpm.ExternalTPM, &vtpm.ExternalTPM)
|
||||
quote, err := vtpm.FetchQuote(vTpmNonce)
|
||||
if err != nil {
|
||||
return []byte{}, errors.Wrap(vtpm.ErrFetchQuote, err)
|
||||
@@ -111,6 +118,14 @@ func (c *defaultMaaClient) Attest(ctx context.Context, nonce []byte, maaURL stri
|
||||
var DefaultMaaClient MaaClient = &defaultMaaClient{}
|
||||
|
||||
func (a provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
|
||||
if isAzureTDX() {
|
||||
token, err := FetchAzureTDXAttestationToken(tokenNonce, MaaURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFetchAzureToken, err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, err := DefaultMaaClient.Attest(context.Background(), tokenNonce, MaaURL, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFetchAzureToken, err)
|
||||
@@ -152,12 +167,19 @@ func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte)
|
||||
func (v verifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
attestation := &attest.Attestation{}
|
||||
if err := proto.Unmarshal(report, attestation); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal attestation report: %w", err)
|
||||
tdxErr := verifyTDXQuoteWithCoRIM(report, manifest)
|
||||
if tdxErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to unmarshal attestation report: %w; Azure TDX verification failed: %v", err, tdxErr)
|
||||
}
|
||||
|
||||
// Extract measurement from SEV-SNP report if present
|
||||
snpRep := attestation.GetSevSnpAttestation()
|
||||
if snpRep == nil {
|
||||
if tdxErr := verifyTDXQuoteWithCoRIM(report, manifest); tdxErr == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("no SEV-SNP attestation found in report")
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,672 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
tdxabi "github.com/google/go-tdx-guest/abi"
|
||||
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
|
||||
"github.com/google/go-tpm/legacy/tpm2"
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
"github.com/veraison/corim/comid"
|
||||
"github.com/veraison/corim/corim"
|
||||
)
|
||||
|
||||
const (
|
||||
tdxAttestEndpoint = "attest/TdxVm"
|
||||
tdxAPIVersion = "2025-06-01"
|
||||
tdxRuntimeBinary = "Binary"
|
||||
tdxRuntimeJSON = "JSON"
|
||||
|
||||
azureHCLReportNVIndex = 0x01400001
|
||||
azureHCLReportDataNVIndex = 0x01400002
|
||||
|
||||
azureHCLSignature = "HCLA"
|
||||
azureHCLVersion = 2
|
||||
azureHCLRequestType = 2
|
||||
azureHCLRuntimeDataVersion = 1
|
||||
azureHCLHashSHA256 = 1
|
||||
azureHCLReportTypeSNP = 2
|
||||
azureHCLReportTypeTDX = 4
|
||||
|
||||
azureHCLHeaderSize = 0x20
|
||||
azureHCLMaxHWReportSize = 0x4a0
|
||||
azureHCLRuntimeDataOffset = azureHCLHeaderSize + azureHCLMaxHWReportSize
|
||||
azureHCLRuntimeClaimsOffset = 0x14
|
||||
azureTDReportSize = 0x400
|
||||
azureTDReportDataOffset = 0x80
|
||||
)
|
||||
|
||||
var (
|
||||
azureTDXHCLReportReader = readAzureHCLReport
|
||||
azureTDXReportDataWriter = writeAzureTDXReportData
|
||||
azureTDXHCLRefreshDelay = 3 * time.Second
|
||||
azureTDXIMDSQuoteURL = "http://169.254.169.254/acc/tdquote"
|
||||
)
|
||||
|
||||
// TDXQuoteFetcher fetches a raw TDX quote for the provided REPORT_DATA.
|
||||
type TDXQuoteFetcher interface {
|
||||
FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type TDXEvidenceFetcher interface {
|
||||
FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error)
|
||||
}
|
||||
|
||||
type defaultTDXQuoteFetcher struct{}
|
||||
|
||||
func (f defaultTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
evidence, err := f.FetchEvidence(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (f defaultTDXQuoteFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
|
||||
hclReport, err := readFreshAzureTDXHCLReport(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedReport, err := parseAzureHCLReport(hclReport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedReport.reportType != azureHCLReportTypeTDX {
|
||||
return nil, fmt.Errorf("Azure HCL report is not TDX")
|
||||
}
|
||||
if err := validateAzureTDXRuntimeClaimsHash(parsedReport); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
quote, err := DefaultAzureTDXIMDSClient.GetQuote(context.Background(), parsedReport.hwReport, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Azure TDX quote from IMDS: %w", err)
|
||||
}
|
||||
|
||||
return &azureTDXEvidence{
|
||||
Quote: quote,
|
||||
RuntimeData: append([]byte(nil), parsedReport.runtimeClaims...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DefaultTDXQuoteFetcher is used by the Azure TDX provider and is replaceable in tests.
|
||||
var DefaultTDXQuoteFetcher TDXQuoteFetcher = defaultTDXQuoteFetcher{}
|
||||
|
||||
// AzureTDXIMDSClient fetches an Azure TDX quote from the Azure Instance Metadata Service.
|
||||
type AzureTDXIMDSClient interface {
|
||||
GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error)
|
||||
}
|
||||
|
||||
type defaultAzureTDXIMDSClient struct{}
|
||||
|
||||
// DefaultAzureTDXIMDSClient is used by the Azure TDX quote fetcher and is replaceable in tests.
|
||||
var DefaultAzureTDXIMDSClient AzureTDXIMDSClient = &defaultAzureTDXIMDSClient{}
|
||||
|
||||
// AzureTDXClient submits Azure TDX VM attestation evidence to Microsoft Azure Attestation.
|
||||
type AzureTDXClient interface {
|
||||
AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error)
|
||||
}
|
||||
|
||||
type defaultAzureTDXClient struct{}
|
||||
|
||||
// DefaultAzureTDXClient is used by Azure TDX token fetching and is replaceable in tests.
|
||||
var DefaultAzureTDXClient AzureTDXClient = &defaultAzureTDXClient{}
|
||||
|
||||
type tdxAttestRequest struct {
|
||||
Quote string `json:"quote"`
|
||||
RuntimeData *tdxDataBlob `json:"runtimeData,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
type tdxDataBlob struct {
|
||||
Data string `json:"data"`
|
||||
DataType string `json:"dataType"`
|
||||
}
|
||||
|
||||
type tdxAttestResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type tdxIMDSQuoteRequest struct {
|
||||
Report string `json:"report"`
|
||||
}
|
||||
|
||||
type tdxIMDSQuoteResponse struct {
|
||||
Quote string `json:"quote"`
|
||||
}
|
||||
|
||||
type azureTDXEvidence struct {
|
||||
Quote []byte
|
||||
RuntimeData []byte
|
||||
}
|
||||
|
||||
type azureHCLReport struct {
|
||||
reportType uint32
|
||||
hashType uint32
|
||||
hwReport []byte
|
||||
runtimeClaims []byte
|
||||
}
|
||||
|
||||
func isAzureTDX() bool {
|
||||
if azureTDXHCLReportReader == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hclReport, err := azureTDXHCLReportReader()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
parsedReport, err := parseAzureHCLReport(hclReport)
|
||||
return err == nil && parsedReport.reportType == azureHCLReportTypeTDX
|
||||
}
|
||||
|
||||
func fetchAzureTDXQuote(teeNonce []byte) ([]byte, error) {
|
||||
if teeNonce == nil {
|
||||
return nil, fmt.Errorf("tee nonce is required for Azure TDX attestation")
|
||||
}
|
||||
if len(teeNonce) != tdxabi.ReportDataSize {
|
||||
return nil, fmt.Errorf("invalid tee nonce length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(teeNonce))
|
||||
}
|
||||
|
||||
var reportData [tdxabi.ReportDataSize]byte
|
||||
copy(reportData[:], teeNonce)
|
||||
|
||||
evidence, err := fetchAzureTDXEvidence(reportData, teeNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (c *defaultAzureTDXClient) AttestTDXVM(ctx context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, client *http.Client) (string, error) {
|
||||
if maaURL == "" {
|
||||
return "", fmt.Errorf("maaURL is empty")
|
||||
}
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
maaURL, err := url.JoinPath(maaURL, tdxAttestEndpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing maaURL: %w", err)
|
||||
}
|
||||
maaURL += fmt.Sprintf("?api-version=%s", tdxAPIVersion)
|
||||
|
||||
attestRequest := tdxAttestRequest{
|
||||
Quote: base64.RawURLEncoding.EncodeToString(quote),
|
||||
}
|
||||
if len(runtimeData) > 0 {
|
||||
attestRequest.RuntimeData = &tdxDataBlob{
|
||||
Data: base64.RawURLEncoding.EncodeToString(runtimeData),
|
||||
DataType: tdxRuntimeDataType(runtimeData),
|
||||
}
|
||||
}
|
||||
if len(nonce) > 0 {
|
||||
attestRequest.Nonce = base64.RawURLEncoding.EncodeToString(nonce)
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(attestRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling TDX attestation request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, maaURL, bytes.NewReader(reqBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating TDX attestation request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("doing TDX attestation request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
|
||||
return "", fmt.Errorf("MAA returned %v: %s", resp.Status, msg)
|
||||
}
|
||||
return "", fmt.Errorf("MAA returned %v", resp.Status)
|
||||
}
|
||||
|
||||
var attestResponse tdxAttestResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&attestResponse); err != nil {
|
||||
return "", fmt.Errorf("decoding TDX attestation response: %w", err)
|
||||
}
|
||||
if attestResponse.Token == "" {
|
||||
return "", fmt.Errorf("azure TDX attestation token not found in response")
|
||||
}
|
||||
|
||||
return attestResponse.Token, nil
|
||||
}
|
||||
|
||||
func (c *defaultAzureTDXIMDSClient) GetQuote(ctx context.Context, tdReport []byte, client *http.Client) ([]byte, error) {
|
||||
if len(tdReport) != azureTDReportSize {
|
||||
return nil, fmt.Errorf("invalid TD report length: expected %d bytes, got %d bytes", azureTDReportSize, len(tdReport))
|
||||
}
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
quoteRequest := tdxIMDSQuoteRequest{
|
||||
Report: base64.RawURLEncoding.EncodeToString(tdReport),
|
||||
}
|
||||
|
||||
reqBytes, err := json.Marshal(quoteRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, azureTDXIMDSQuoteURL, bytes.NewReader(reqBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("doing Azure TDX IMDS quote request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if msg, err := io.ReadAll(resp.Body); err == nil && len(msg) > 0 {
|
||||
return nil, fmt.Errorf("Azure TDX IMDS returned %v: %s", resp.Status, msg)
|
||||
}
|
||||
return nil, fmt.Errorf("Azure TDX IMDS returned %v", resp.Status)
|
||||
}
|
||||
|
||||
var quoteResponse tdxIMDSQuoteResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode("eResponse); err != nil {
|
||||
return nil, fmt.Errorf("decoding Azure TDX IMDS quote response: %w", err)
|
||||
}
|
||||
if quoteResponse.Quote == "" {
|
||||
return nil, fmt.Errorf("Azure TDX IMDS quote not found in response")
|
||||
}
|
||||
|
||||
quote, err := decodeBase64URL(quoteResponse.Quote)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding Azure TDX IMDS quote: %w", err)
|
||||
}
|
||||
|
||||
return quote, nil
|
||||
}
|
||||
|
||||
// FetchAzureTDXAttestationToken fetches an Azure Attestation token for an Azure TDX VM.
|
||||
func FetchAzureTDXAttestationToken(tokenNonce []byte, maaURL string) ([]byte, error) {
|
||||
reportData := tdxReportDataFromRuntimeData(tokenNonce)
|
||||
evidence, err := fetchAzureTDXEvidence(reportData, tokenNonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Azure TDX quote: %w", err)
|
||||
}
|
||||
|
||||
token, err := DefaultAzureTDXClient.AttestTDXVM(context.Background(), evidence.Quote, evidence.RuntimeData, tokenNonce, maaURL, http.DefaultClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching azure TDX token: %w", err)
|
||||
}
|
||||
|
||||
return []byte(token), nil
|
||||
}
|
||||
|
||||
func tdxReportDataFromRuntimeData(runtimeData []byte) [tdxabi.ReportDataSize]byte {
|
||||
hash := sha256.Sum256(runtimeData)
|
||||
var reportData [tdxabi.ReportDataSize]byte
|
||||
copy(reportData[:sha256.Size], hash[:])
|
||||
return reportData
|
||||
}
|
||||
|
||||
func fetchAzureTDXEvidence(reportData [tdxabi.ReportDataSize]byte, fallbackRuntimeData []byte) (*azureTDXEvidence, error) {
|
||||
if evidenceFetcher, ok := DefaultTDXQuoteFetcher.(TDXEvidenceFetcher); ok {
|
||||
return evidenceFetcher.FetchEvidence(reportData)
|
||||
}
|
||||
|
||||
quote, err := DefaultTDXQuoteFetcher.FetchQuote(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &azureTDXEvidence{
|
||||
Quote: quote,
|
||||
RuntimeData: append([]byte(nil), fallbackRuntimeData...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readFreshAzureTDXHCLReport(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
if azureTDXReportDataWriter != nil {
|
||||
if err := azureTDXReportDataWriter(reportData[:]); err != nil {
|
||||
return nil, fmt.Errorf("writing Azure TDX report data: %w", err)
|
||||
}
|
||||
if azureTDXHCLRefreshDelay > 0 {
|
||||
time.Sleep(azureTDXHCLRefreshDelay)
|
||||
}
|
||||
}
|
||||
if azureTDXHCLReportReader == nil {
|
||||
return nil, fmt.Errorf("Azure TDX HCL report reader is not configured")
|
||||
}
|
||||
|
||||
hclReport, err := azureTDXHCLReportReader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading Azure TDX HCL report: %w", err)
|
||||
}
|
||||
|
||||
return hclReport, nil
|
||||
}
|
||||
|
||||
func readAzureHCLReport() ([]byte, error) {
|
||||
tpm, err := tpm2.OpenTPM()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return tpm2.NVReadEx(tpm, azureHCLReportNVIndex, tpm2.HandleOwner, "", 0)
|
||||
}
|
||||
|
||||
func writeAzureTDXReportData(data []byte) error {
|
||||
if len(data) != tdxabi.ReportDataSize {
|
||||
return fmt.Errorf("invalid Azure TDX report data length: expected %d bytes, got %d bytes", tdxabi.ReportDataSize, len(data))
|
||||
}
|
||||
|
||||
tpm, err := tpm2.OpenTPM()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return writeAzureTDXReportDataToTPM(tpm, data)
|
||||
}
|
||||
|
||||
func writeAzureTDXReportDataToTPM(tpm io.ReadWriter, data []byte) error {
|
||||
if len(data) > int(^uint16(0)) {
|
||||
return fmt.Errorf("Azure TDX report data is too large")
|
||||
}
|
||||
|
||||
if err := ensureAzureTDXReportDataIndex(tpm, uint16(len(data))); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tpm2.NVWrite(tpm, tpm2.HandleOwner, azureHCLReportDataNVIndex, "", tpmutil.U16Bytes(data), 0); err != nil {
|
||||
return fmt.Errorf("writing Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAzureTDXReportDataIndex(tpm io.ReadWriter, size uint16) error {
|
||||
pub, err := tpm2.NVReadPublic(tpm, azureHCLReportDataNVIndex)
|
||||
if err == nil {
|
||||
if pub.DataSize == size {
|
||||
return nil
|
||||
}
|
||||
if err := tpm2.NVUndefineSpace(tpm, "", tpm2.HandleOwner, azureHCLReportDataNVIndex); err != nil {
|
||||
return fmt.Errorf("undefining mismatched Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
nvPub := tpm2.NVPublic{
|
||||
NVIndex: azureHCLReportDataNVIndex,
|
||||
NameAlg: tpm2.AlgSHA256,
|
||||
Attributes: tpm2.AttrOwnerWrite | tpm2.AttrOwnerRead,
|
||||
DataSize: size,
|
||||
}
|
||||
authArea := tpm2.AuthCommand{
|
||||
Session: tpm2.HandlePasswordSession,
|
||||
Attributes: tpm2.AttrContinueSession,
|
||||
Auth: []byte(""),
|
||||
}
|
||||
if err := tpm2.NVDefineSpaceEx(tpm, tpm2.HandleOwner, "", nvPub, authArea); err != nil {
|
||||
return fmt.Errorf("defining Azure TDX report-data NV index: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAzureHCLReport(report []byte) (*azureHCLReport, error) {
|
||||
minSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset
|
||||
if len(report) < minSize {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report size: expected at least %d bytes, got %d bytes", minSize, len(report))
|
||||
}
|
||||
if string(report[:len(azureHCLSignature)]) != azureHCLSignature {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report signature")
|
||||
}
|
||||
if version := binary.LittleEndian.Uint32(report[4:8]); version != azureHCLVersion {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report version: expected %d, got %d", azureHCLVersion, version)
|
||||
}
|
||||
reportSize := binary.LittleEndian.Uint32(report[8:12])
|
||||
if reportSize > uint32(len(report)) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report size: header reports %d bytes, got %d bytes", reportSize, len(report))
|
||||
}
|
||||
if requestType := binary.LittleEndian.Uint32(report[12:16]); requestType != azureHCLRequestType {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report request type: expected %d, got %d", azureHCLRequestType, requestType)
|
||||
}
|
||||
|
||||
runtimeData := report[azureHCLRuntimeDataOffset:]
|
||||
dataSize := binary.LittleEndian.Uint32(runtimeData[0:4])
|
||||
if dataSize < azureHCLRuntimeClaimsOffset {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data size: %d", dataSize)
|
||||
}
|
||||
if azureHCLRuntimeDataOffset+int(dataSize) > len(report) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data size: header reports %d bytes", dataSize)
|
||||
}
|
||||
if version := binary.LittleEndian.Uint32(runtimeData[4:8]); version != azureHCLRuntimeDataVersion {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime data version: expected %d, got %d", azureHCLRuntimeDataVersion, version)
|
||||
}
|
||||
|
||||
reportType := binary.LittleEndian.Uint32(runtimeData[8:12])
|
||||
if reportType != azureHCLReportTypeSNP && reportType != azureHCLReportTypeTDX {
|
||||
return nil, fmt.Errorf("invalid Azure HCL report type: %d", reportType)
|
||||
}
|
||||
|
||||
hashType := binary.LittleEndian.Uint32(runtimeData[12:16])
|
||||
claimsSize := binary.LittleEndian.Uint32(runtimeData[16:20])
|
||||
claimsEnd := azureHCLRuntimeClaimsOffset + int(claimsSize)
|
||||
if claimsEnd > int(dataSize) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL runtime claims size: %d", claimsSize)
|
||||
}
|
||||
|
||||
hwReportSize := azureHCLMaxHWReportSize
|
||||
if reportType == azureHCLReportTypeTDX {
|
||||
hwReportSize = azureTDReportSize
|
||||
}
|
||||
if azureHCLHeaderSize+hwReportSize > len(report) {
|
||||
return nil, fmt.Errorf("invalid Azure HCL hardware report size: %d", hwReportSize)
|
||||
}
|
||||
|
||||
return &azureHCLReport{
|
||||
reportType: reportType,
|
||||
hashType: hashType,
|
||||
hwReport: append([]byte(nil), report[azureHCLHeaderSize:azureHCLHeaderSize+hwReportSize]...),
|
||||
runtimeClaims: append([]byte(nil), runtimeData[azureHCLRuntimeClaimsOffset:claimsEnd]...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateAzureTDXRuntimeClaimsHash(report *azureHCLReport) error {
|
||||
if report.hashType != azureHCLHashSHA256 {
|
||||
return fmt.Errorf("unsupported Azure HCL runtime data hash type: %d", report.hashType)
|
||||
}
|
||||
if len(report.hwReport) < azureTDReportDataOffset+sha256.Size {
|
||||
return fmt.Errorf("invalid Azure TDX TD report size: %d", len(report.hwReport))
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(report.runtimeClaims)
|
||||
if !bytes.Equal(hash[:], report.hwReport[azureTDReportDataOffset:azureTDReportDataOffset+sha256.Size]) {
|
||||
return fmt.Errorf("Azure TDX runtime claims hash does not match TD report data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tdxRuntimeDataType(runtimeData []byte) string {
|
||||
if json.Valid(runtimeData) {
|
||||
return tdxRuntimeJSON
|
||||
}
|
||||
return tdxRuntimeBinary
|
||||
}
|
||||
|
||||
func decodeBase64URL(value string) ([]byte, error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
return base64.URLEncoding.DecodeString(value)
|
||||
}
|
||||
|
||||
func verifyTDXQuoteWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
decodedQuote, err := tdxabi.QuoteToProto(report)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TDX quote: %w", err)
|
||||
}
|
||||
|
||||
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported TDX quote format")
|
||||
}
|
||||
|
||||
tdReport := quoteV4.GetTdQuoteBody()
|
||||
if tdReport == nil {
|
||||
return fmt.Errorf("missing TDX quote body")
|
||||
}
|
||||
|
||||
mrtd := tdReport.GetMrTd()
|
||||
if len(mrtd) == 0 {
|
||||
return fmt.Errorf("no MRTD in TDX quote")
|
||||
}
|
||||
|
||||
if err := matchMeasurementInCoRIM(manifest, mrtd); err != nil {
|
||||
return fmt.Errorf("%w for Azure TDX", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func matchMeasurementInCoRIM(manifest *corim.UnsignedCorim, measurement []byte) error {
|
||||
if manifest == nil || len(manifest.Tags) == 0 {
|
||||
return fmt.Errorf("no tags in CoRIM")
|
||||
}
|
||||
|
||||
for _, tag := range manifest.Tags {
|
||||
if !bytes.HasPrefix(tag, corim.ComidTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
tagValue := tag[len(corim.ComidTag):]
|
||||
|
||||
var c comid.Comid
|
||||
if err := c.FromCBOR(tagValue); err != nil {
|
||||
return fmt.Errorf("failed to parse CoMID: %w", err)
|
||||
}
|
||||
|
||||
if c.Triples.ReferenceValues == nil {
|
||||
continue
|
||||
}
|
||||
for _, rv := range *c.Triples.ReferenceValues {
|
||||
for _, m := range rv.Measurements {
|
||||
if m.Val.Digests == nil {
|
||||
continue
|
||||
}
|
||||
for _, digest := range *m.Val.Digests {
|
||||
if bytes.Equal(digest.HashValue, measurement) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("no matching reference value found in CoRIM")
|
||||
}
|
||||
|
||||
// AzureTDXMeasurementData contains the fields extracted from an Azure TDX attestation token
|
||||
// needed to construct a CoRIM policy for the TDX platform.
|
||||
type AzureTDXMeasurementData struct {
|
||||
MRTD string
|
||||
MRSEAM string
|
||||
RTMRs []string
|
||||
SEAMSVN uint64
|
||||
}
|
||||
|
||||
// ExtractAzureTDXMeasurement extracts core TDX measurements from an Azure Attestation token.
|
||||
func ExtractAzureTDXMeasurement(token string) (*AzureTDXMeasurementData, error) {
|
||||
claims, err := DefaultValidator.Validate(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate token: %w", err)
|
||||
}
|
||||
|
||||
mrtd, ok := azureClaimString(claims, "tdx_mrtd")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get MRTD from claims")
|
||||
}
|
||||
|
||||
mrSeam, _ := azureClaimString(claims, "tdx_mrseam")
|
||||
|
||||
rtmrs := make([]string, 0, 4)
|
||||
for _, name := range []string{"tdx_rtmr0", "tdx_rtmr1", "tdx_rtmr2", "tdx_rtmr3"} {
|
||||
if value, ok := azureClaimString(claims, name); ok {
|
||||
rtmrs = append(rtmrs, value)
|
||||
}
|
||||
}
|
||||
|
||||
seamSVN, _ := azureClaimUint64(claims, "tdx_seamsvn")
|
||||
|
||||
return &AzureTDXMeasurementData{
|
||||
MRTD: mrtd,
|
||||
MRSEAM: mrSeam,
|
||||
RTMRs: rtmrs,
|
||||
SEAMSVN: seamSVN,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func azureClaimString(claims map[string]any, name string) (string, bool) {
|
||||
if value, ok := claims[name].(string); ok {
|
||||
return value, true
|
||||
}
|
||||
|
||||
tee, ok := claims["x-ms-isolation-tee"].(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
value, ok := tee[name].(string)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func azureClaimUint64(claims map[string]any, name string) (uint64, bool) {
|
||||
value, ok := claims[name]
|
||||
if !ok {
|
||||
tee, teeOK := claims["x-ms-isolation-tee"].(map[string]any)
|
||||
if !teeOK {
|
||||
return 0, false
|
||||
}
|
||||
value, ok = tee[name]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return uint64(typed), true
|
||||
case int:
|
||||
return uint64(typed), true
|
||||
case uint64:
|
||||
return typed, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,455 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
tdxabi "github.com/google/go-tdx-guest/abi"
|
||||
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
|
||||
tdxtestdata "github.com/google/go-tdx-guest/testing/testdata"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/veraison/corim/comid"
|
||||
"github.com/veraison/corim/corim"
|
||||
"github.com/veraison/swid"
|
||||
)
|
||||
|
||||
type mockTDXQuoteFetcher struct {
|
||||
quote []byte
|
||||
err error
|
||||
gotReportData [tdxabi.ReportDataSize]byte
|
||||
fetchQuoteCall bool
|
||||
}
|
||||
|
||||
func (m *mockTDXQuoteFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
m.fetchQuoteCall = true
|
||||
m.gotReportData = reportData
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.quote, nil
|
||||
}
|
||||
|
||||
type mockTDXEvidenceFetcher struct {
|
||||
evidence *azureTDXEvidence
|
||||
err error
|
||||
gotReportData [tdxabi.ReportDataSize]byte
|
||||
fetchEvidenceHit bool
|
||||
}
|
||||
|
||||
func (m *mockTDXEvidenceFetcher) FetchQuote(reportData [tdxabi.ReportDataSize]byte) ([]byte, error) {
|
||||
evidence, err := m.FetchEvidence(reportData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return evidence.Quote, nil
|
||||
}
|
||||
|
||||
func (m *mockTDXEvidenceFetcher) FetchEvidence(reportData [tdxabi.ReportDataSize]byte) (*azureTDXEvidence, error) {
|
||||
m.fetchEvidenceHit = true
|
||||
m.gotReportData = reportData
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.evidence, nil
|
||||
}
|
||||
|
||||
type mockAzureTDXClient struct {
|
||||
token string
|
||||
err error
|
||||
gotQuote []byte
|
||||
gotRuntime []byte
|
||||
gotNonce []byte
|
||||
gotMaaURL string
|
||||
attestCalls int
|
||||
}
|
||||
|
||||
func (m *mockAzureTDXClient) AttestTDXVM(_ context.Context, quote []byte, runtimeData []byte, nonce []byte, maaURL string, _ *http.Client) (string, error) {
|
||||
m.attestCalls++
|
||||
m.gotQuote = append([]byte(nil), quote...)
|
||||
m.gotRuntime = append([]byte(nil), runtimeData...)
|
||||
m.gotNonce = append([]byte(nil), nonce...)
|
||||
m.gotMaaURL = maaURL
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.token, nil
|
||||
}
|
||||
|
||||
type mockAzureTDXIMDSClient struct {
|
||||
quote []byte
|
||||
err error
|
||||
gotTDReport []byte
|
||||
getQuoteCall bool
|
||||
}
|
||||
|
||||
func (m *mockAzureTDXIMDSClient) GetQuote(_ context.Context, tdReport []byte, _ *http.Client) ([]byte, error) {
|
||||
m.getQuoteCall = true
|
||||
m.gotTDReport = append([]byte(nil), tdReport...)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.quote, nil
|
||||
}
|
||||
|
||||
func testAzureHCLReport(reportType uint32, runtimeClaims []byte) []byte {
|
||||
reportSize := azureHCLRuntimeDataOffset + azureHCLRuntimeClaimsOffset + len(runtimeClaims)
|
||||
hclReport := make([]byte, reportSize)
|
||||
copy(hclReport[:len(azureHCLSignature)], azureHCLSignature)
|
||||
binary.LittleEndian.PutUint32(hclReport[4:8], azureHCLVersion)
|
||||
binary.LittleEndian.PutUint32(hclReport[8:12], uint32(reportSize))
|
||||
binary.LittleEndian.PutUint32(hclReport[12:16], azureHCLRequestType)
|
||||
|
||||
runtimeData := hclReport[azureHCLRuntimeDataOffset:]
|
||||
binary.LittleEndian.PutUint32(runtimeData[0:4], uint32(azureHCLRuntimeClaimsOffset+len(runtimeClaims)))
|
||||
binary.LittleEndian.PutUint32(runtimeData[4:8], azureHCLRuntimeDataVersion)
|
||||
binary.LittleEndian.PutUint32(runtimeData[8:12], reportType)
|
||||
binary.LittleEndian.PutUint32(runtimeData[12:16], azureHCLHashSHA256)
|
||||
binary.LittleEndian.PutUint32(runtimeData[16:20], uint32(len(runtimeClaims)))
|
||||
copy(runtimeData[azureHCLRuntimeClaimsOffset:], runtimeClaims)
|
||||
|
||||
if reportType == azureHCLReportTypeTDX {
|
||||
hash := sha256.Sum256(runtimeClaims)
|
||||
copy(hclReport[azureHCLHeaderSize+azureTDReportDataOffset:], hash[:])
|
||||
}
|
||||
|
||||
return hclReport
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation_AzureTDX(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
}()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
fetcher := &mockTDXQuoteFetcher{quote: []byte("tdx-quote")}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
|
||||
reportData := bytes.Repeat([]byte{0xAB}, tdxabi.ReportDataSize)
|
||||
got, err := NewProvider().TeeAttestation(reportData)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-quote"), got)
|
||||
assert.True(t, fetcher.fetchQuoteCall)
|
||||
assert.Equal(t, reportData, fetcher.gotReportData[:])
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation_AzureTDX_InvalidNonce(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
defer func() { azureTDXHCLReportReader = oldReader }()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
|
||||
_, err := NewProvider().TeeAttestation([]byte("short"))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid tee nonce length")
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken_AzureTDX(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
oldClient := DefaultAzureTDXClient
|
||||
oldMaaURL := MaaURL
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
DefaultAzureTDXClient = oldClient
|
||||
MaaURL = oldMaaURL
|
||||
}()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
MaaURL = "https://tdx.example.attest.azure.net"
|
||||
|
||||
fetcher := &mockTDXQuoteFetcher{quote: []byte("quote")}
|
||||
client := &mockAzureTDXClient{token: "tdx-token"}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
DefaultAzureTDXClient = client
|
||||
|
||||
nonce := []byte("token-nonce")
|
||||
got, err := NewProvider().AzureAttestationToken(nonce)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-token"), got)
|
||||
|
||||
expectedReportData := tdxReportDataFromRuntimeData(nonce)
|
||||
assert.Equal(t, expectedReportData, fetcher.gotReportData)
|
||||
assert.Equal(t, []byte("quote"), client.gotQuote)
|
||||
assert.Equal(t, nonce, client.gotRuntime)
|
||||
assert.Equal(t, nonce, client.gotNonce)
|
||||
assert.Equal(t, MaaURL, client.gotMaaURL)
|
||||
}
|
||||
|
||||
func TestFetchAzureTDXAttestationToken_UsesHCLRuntimeClaims(t *testing.T) {
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
oldClient := DefaultAzureTDXClient
|
||||
defer func() {
|
||||
DefaultTDXQuoteFetcher = oldFetcher
|
||||
DefaultAzureTDXClient = oldClient
|
||||
}()
|
||||
|
||||
runtimeClaims := []byte(`{"keys":[],"user-data":"nonce"}`)
|
||||
fetcher := &mockTDXEvidenceFetcher{
|
||||
evidence: &azureTDXEvidence{
|
||||
Quote: []byte("quote"),
|
||||
RuntimeData: runtimeClaims,
|
||||
},
|
||||
}
|
||||
client := &mockAzureTDXClient{token: "tdx-token"}
|
||||
DefaultTDXQuoteFetcher = fetcher
|
||||
DefaultAzureTDXClient = client
|
||||
|
||||
nonce := []byte("token-nonce")
|
||||
got, err := FetchAzureTDXAttestationToken(nonce, "https://tdx.example.attest.azure.net")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("tdx-token"), got)
|
||||
assert.True(t, fetcher.fetchEvidenceHit)
|
||||
assert.Equal(t, tdxReportDataFromRuntimeData(nonce), fetcher.gotReportData)
|
||||
assert.Equal(t, runtimeClaims, client.gotRuntime)
|
||||
assert.Equal(t, nonce, client.gotNonce)
|
||||
}
|
||||
|
||||
func TestFetchAzureTDXAttestationToken_FetchQuoteError(t *testing.T) {
|
||||
oldFetcher := DefaultTDXQuoteFetcher
|
||||
defer func() { DefaultTDXQuoteFetcher = oldFetcher }()
|
||||
|
||||
DefaultTDXQuoteFetcher = &mockTDXQuoteFetcher{err: fmt.Errorf("quote unavailable")}
|
||||
|
||||
_, err := FetchAzureTDXAttestationToken([]byte("nonce"), "https://tdx.example.attest.azure.net")
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to fetch Azure TDX quote")
|
||||
}
|
||||
|
||||
func TestDefaultTDXQuoteFetcher_FetchEvidence_AzureHCLIMDS(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
oldWriter := azureTDXReportDataWriter
|
||||
oldDelay := azureTDXHCLRefreshDelay
|
||||
oldIMDSClient := DefaultAzureTDXIMDSClient
|
||||
defer func() {
|
||||
azureTDXHCLReportReader = oldReader
|
||||
azureTDXReportDataWriter = oldWriter
|
||||
azureTDXHCLRefreshDelay = oldDelay
|
||||
DefaultAzureTDXIMDSClient = oldIMDSClient
|
||||
}()
|
||||
|
||||
runtimeClaims := []byte(`{"keys":[],"user-data":"fresh"}`)
|
||||
hclReport := testAzureHCLReport(azureHCLReportTypeTDX, runtimeClaims)
|
||||
var gotReportData []byte
|
||||
azureTDXReportDataWriter = func(data []byte) error {
|
||||
gotReportData = append([]byte(nil), data...)
|
||||
return nil
|
||||
}
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return hclReport, nil
|
||||
}
|
||||
azureTDXHCLRefreshDelay = 0
|
||||
imdsClient := &mockAzureTDXIMDSClient{quote: []byte("tdx-quote")}
|
||||
DefaultAzureTDXIMDSClient = imdsClient
|
||||
|
||||
reportData := [tdxabi.ReportDataSize]byte{0xAB}
|
||||
evidence, err := defaultTDXQuoteFetcher{}.FetchEvidence(reportData)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, reportData[:], gotReportData)
|
||||
assert.Equal(t, []byte("tdx-quote"), evidence.Quote)
|
||||
assert.Equal(t, runtimeClaims, evidence.RuntimeData)
|
||||
assert.True(t, imdsClient.getQuoteCall)
|
||||
assert.Len(t, imdsClient.gotTDReport, azureTDReportSize)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "/attest/TdxVm", r.URL.Path)
|
||||
assert.Equal(t, tdxAPIVersion, r.URL.Query().Get("api-version"))
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
var req tdxAttestRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("quote")), req.Quote)
|
||||
require.NotNil(t, req.RuntimeData)
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("runtime")), req.RuntimeData.Data)
|
||||
assert.Equal(t, tdxRuntimeBinary, req.RuntimeData.DataType)
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString([]byte("nonce")), req.Nonce)
|
||||
|
||||
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
token, err := (&defaultAzureTDXClient{}).AttestTDXVM(
|
||||
context.Background(),
|
||||
[]byte("quote"),
|
||||
[]byte("runtime"),
|
||||
[]byte("nonce"),
|
||||
server.URL,
|
||||
server.Client(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "tdx-token", token)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM_JSONRuntimeData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req tdxAttestRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
require.NotNil(t, req.RuntimeData)
|
||||
assert.Equal(t, tdxRuntimeJSON, req.RuntimeData.DataType)
|
||||
|
||||
_, _ = w.Write([]byte(`{"token":"tdx-token"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(
|
||||
context.Background(),
|
||||
[]byte("quote"),
|
||||
[]byte(`{"keys":[]}`),
|
||||
[]byte("nonce"),
|
||||
server.URL,
|
||||
server.Client(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXClient_AttestTDXVM_ErrorStatus(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad quote", http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (&defaultAzureTDXClient{}).AttestTDXVM(context.Background(), []byte("quote"), nil, nil, server.URL, server.Client())
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "MAA returned 400 Bad Request")
|
||||
}
|
||||
|
||||
func TestDefaultAzureTDXIMDSClient_GetQuote(t *testing.T) {
|
||||
oldURL := azureTDXIMDSQuoteURL
|
||||
defer func() { azureTDXIMDSQuoteURL = oldURL }()
|
||||
|
||||
tdReport := bytes.Repeat([]byte{0xA5}, azureTDReportSize)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
|
||||
var req tdxIMDSQuoteRequest
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, base64.RawURLEncoding.EncodeToString(tdReport), req.Report)
|
||||
|
||||
resp := tdxIMDSQuoteResponse{
|
||||
Quote: base64.RawURLEncoding.EncodeToString([]byte("quote")),
|
||||
}
|
||||
require.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
}))
|
||||
defer server.Close()
|
||||
azureTDXIMDSQuoteURL = server.URL
|
||||
|
||||
quote, err := (&defaultAzureTDXIMDSClient{}).GetQuote(context.Background(), tdReport, server.Client())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("quote"), quote)
|
||||
}
|
||||
|
||||
func TestIsAzureTDX_UsesHCLReportType(t *testing.T) {
|
||||
oldReader := azureTDXHCLReportReader
|
||||
defer func() { azureTDXHCLReportReader = oldReader }()
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeSNP, []byte("runtime")), nil
|
||||
}
|
||||
assert.False(t, isAzureTDX())
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return testAzureHCLReport(azureHCLReportTypeTDX, []byte(`{"keys":[]}`)), nil
|
||||
}
|
||||
assert.True(t, isAzureTDX())
|
||||
|
||||
azureTDXHCLReportReader = func() ([]byte, error) {
|
||||
return nil, fmt.Errorf("no vTPM")
|
||||
}
|
||||
assert.False(t, isAzureTDX())
|
||||
}
|
||||
|
||||
func TestExtractAzureTDXMeasurement_Success(t *testing.T) {
|
||||
oldValidator := DefaultValidator
|
||||
defer func() { DefaultValidator = oldValidator }()
|
||||
|
||||
DefaultValidator = &mockTokenValidator{
|
||||
validateFunc: func(token string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"tdx_mrtd": "mrtd",
|
||||
"tdx_mrseam": "mrseam",
|
||||
"tdx_rtmr0": "rtmr0",
|
||||
"tdx_rtmr1": "rtmr1",
|
||||
"tdx_rtmr2": "rtmr2",
|
||||
"tdx_rtmr3": "rtmr3",
|
||||
"tdx_seamsvn": float64(7),
|
||||
"unrelatedKey": "ignored",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
data, err := ExtractAzureTDXMeasurement("valid-token")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &AzureTDXMeasurementData{
|
||||
MRTD: "mrtd",
|
||||
MRSEAM: "mrseam",
|
||||
RTMRs: []string{"rtmr0", "rtmr1", "rtmr2", "rtmr3"},
|
||||
SEAMSVN: 7,
|
||||
}, data)
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyWithCoRIM_AzureTDX(t *testing.T) {
|
||||
decodedQuote, err := tdxabi.QuoteToProto(tdxtestdata.RawQuote)
|
||||
require.NoError(t, err)
|
||||
|
||||
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
|
||||
require.True(t, ok)
|
||||
mrtd := quoteV4.GetTdQuoteBody().GetMrTd()
|
||||
|
||||
c := comid.NewComid()
|
||||
c.SetTagIdentity("tdx-tag", 0)
|
||||
|
||||
m := comid.MustNewUintMeasurement(uint64(1))
|
||||
m.AddDigest(swid.Sha384, mrtd)
|
||||
m.SetRawValueBytes([]byte("raw"), nil)
|
||||
|
||||
rv := comid.ReferenceValue{
|
||||
Environment: comid.Environment{
|
||||
Class: comid.NewClassOID("1.2.3.4"),
|
||||
},
|
||||
Measurements: comid.Measurements{*m},
|
||||
}
|
||||
c.AddReferenceValue(rv)
|
||||
|
||||
manifest := corim.NewUnsignedCorim()
|
||||
manifest.SetID("test-tdx-corim")
|
||||
manifest.AddComid(*c)
|
||||
|
||||
err = NewVerifier(&bytes.Buffer{}).VerifyWithCoRIM(tdxtestdata.RawQuote, manifest)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -55,6 +55,8 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
|
||||
case attestation.SNPvTPM:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
|
||||
case attestation.Azure:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
|
||||
default:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
|
||||
}
|
||||
@@ -92,6 +94,8 @@ func (c *client) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
|
||||
case attestation.SNPvTPM:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
|
||||
case attestation.Azure:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_AZURE
|
||||
default:
|
||||
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
@@ -174,6 +174,40 @@ func TestGetAttestationTDX(t *testing.T) {
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationAzure tests getting Azure attestation.
|
||||
func TestGetAttestationAzure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-azure-evidence.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.Azure)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationVTPM tests getting vTPM attestation.
|
||||
func TestGetAttestationVTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@@ -479,6 +513,40 @@ func TestGetRawEvidenceTDX(t *testing.T) {
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetRawEvidenceAzure tests getting raw evidence for Azure platform.
|
||||
func TestGetRawEvidenceAzure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "raw-evidence-azure.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.Azure)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, evidence)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_AZURE, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetRawEvidenceVTPM tests getting raw evidence for VTPM platform.
|
||||
func TestGetRawEvidenceVTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
Reference in New Issue
Block a user