diff --git a/pkg/atls/ea/authenticator_test.go b/pkg/atls/ea/authenticator_test.go index 2d10fbdc..50aecf4b 100644 --- a/pkg/atls/ea/authenticator_test.go +++ b/pkg/atls/ea/authenticator_test.go @@ -20,7 +20,7 @@ import ( attestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" ) -func selfSignedCert(t *testing.T) tls.Certificate { +func selfSignedCert(t testing.TB) tls.Certificate { t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -42,7 +42,7 @@ func selfSignedCert(t *testing.T) tls.Certificate { return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv} } -func tlsPair(t *testing.T, cert tls.Certificate) (srv, cli *tls.Conn) { +func tlsPair(t testing.TB, cert tls.Certificate) (srv, cli *tls.Conn) { t.Helper() srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} diff --git a/pkg/atls/internal_transport/conn_test.go b/pkg/atls/internal_transport/conn_test.go index 86435026..e8bd4b7b 100644 --- a/pkg/atls/internal_transport/conn_test.go +++ b/pkg/atls/internal_transport/conn_test.go @@ -16,7 +16,7 @@ import ( "time" ) -func selfSignedCert(t *testing.T) tls.Certificate { +func selfSignedCert(t testing.TB) tls.Certificate { t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) diff --git a/pkg/attestation/eat/extractor.go b/pkg/attestation/eat/extractor.go index a213cef1..ac8c26ee 100644 --- a/pkg/attestation/eat/extractor.go +++ b/pkg/attestation/eat/extractor.go @@ -5,11 +5,16 @@ package eat import ( "encoding/binary" + "encoding/hex" "fmt" "github.com/google/go-sev-guest/abi" + sevsnppb "github.com/google/go-sev-guest/proto/sevsnp" tdxabi "github.com/google/go-tdx-guest/abi" tdxpb "github.com/google/go-tdx-guest/proto/tdx" + attestpb "github.com/google/go-tpm-tools/proto/attest" + tpmpb "github.com/google/go-tpm-tools/proto/tpm" + "google.golang.org/protobuf/proto" ) // OEMID constants (Private Enterprise Numbers). @@ -20,18 +25,45 @@ const ( ) // extractSNPClaims extracts AMD SEV-SNP specific claims from binary report. +// report may be one of three formats: +// 1. proto-marshaled sevsnp.Attestation (SNP-only platform, from fetchSEVAttestation) +// 2. proto-marshaled attest.Attestation (go-tpm-tools, SNP_VTPM platform, field 7 = SevSnpAttestation) +// 3. raw binary SNP report (0x4A0 bytes) func extractSNPClaims(claims *EATClaims, report []byte) error { + // Try sevsnp.Attestation (SNP-only proto format). + var sevAttest sevsnppb.Attestation + if err := proto.Unmarshal(report, &sevAttest); err == nil { + if r := sevAttest.GetReport(); r != nil { + return populateSNPClaims(claims, r) + } + } + + // Try attest.Attestation (go-tpm-tools SNP_VTPM format). + var tpmAttest attestpb.Attestation + if err := proto.Unmarshal(report, &tpmAttest); err == nil { + if snp := tpmAttest.GetSevSnpAttestation(); snp != nil { + if r := snp.GetReport(); r != nil { + if err := populateSNPClaims(claims, r); err != nil { + return err + } + populateVTPMClaims(claims, tpmAttest.GetQuotes(), tpmAttest.GetEventLog()) + return nil + } + } + } + + // Fall back to raw binary SNP report. if len(report) < int(abi.ReportSize) { return fmt.Errorf("SNP report too small: got %d bytes, want at least %d", len(report), abi.ReportSize) } - - // Parse SNP report structure snpReport, err := abi.ReportToProto(report[:abi.ReportSize]) if err != nil { return fmt.Errorf("failed to parse SNP report: %w", err) } + return populateSNPClaims(claims, snpReport) +} - // Extract SNP-specific fields +func populateSNPClaims(claims *EATClaims, snpReport *sevsnppb.Report) error { claims.SNPExtensions = &SNPExtensions{ Measurement: snpReport.Measurement, Policy: snpReport.Policy, @@ -42,23 +74,16 @@ func extractSNPClaims(claims *EATClaims, report []byte) error { PlatformInfo: snpReport.PlatformInfo, ChipID: snpReport.ChipId, } - - // Set TCB version info claims.SNPExtensions.CurrentTCB = snpReport.CurrentTcb claims.SNPExtensions.ReportedTCB = snpReport.ReportedTcb claims.SNPExtensions.CommittedTCB = snpReport.CommittedTcb claims.SNPExtensions.LaunchTCB = snpReport.LaunchTcb claims.SNPExtensions.TCB = fmt.Sprintf("current:%d,reported:%d", snpReport.CurrentTcb, snpReport.ReportedTcb) - - // Set core EAT claims from SNP report - claims.Measurements = snpReport.Measurement - claims.UEID = snpReport.ChipId // Use ChipID as UEID - claims.OEMID = OEMID_AMD // AMD's PEN (Private Enterprise Number) claims.SNPExtensions.Signature = snpReport.Signature - - // Set hardware model (hash of product name) + claims.Measurements = snpReport.Measurement + claims.UEID = snpReport.ChipId + claims.OEMID = OEMID_AMD claims.HWModel = []byte(fmt.Sprintf("SEV-SNP-%d", snpReport.Version)) - return nil } @@ -122,18 +147,77 @@ func extractTDXClaims(claims *EATClaims, report []byte) error { return nil } -// extractVTPMClaims extracts vTPM specific claims from binary report. -func extractVTPMClaims(claims *EATClaims, report []byte) error { - // vTPM report is typically a marshaled structure containing PCRs and quote - // For now, store the entire report as the quote - claims.VTPMExtensions = &VTPMExtensions{ - Quote: report, - PCRs: make(map[string]string), +// populateVTPMClaims fills VTPMExtensions from go-tpm-tools quote banks and event log. +// For SNP_VTPM the SNP measurement is already set; this adds PCR values alongside it. +// PCR keys are formatted as ":" (e.g. "sha256:0"), values are hex-encoded. +// The raw TPMS_ATTEST bytes from the SHA-256 bank are stored as the canonical Quote. +func populateVTPMClaims(claims *EATClaims, quotes []*tpmpb.Quote, eventLog []byte) { + vtpm := &VTPMExtensions{ + PCRs: make(map[string]string), + EventLog: eventLog, } - // Set core EAT claims - claims.Measurements = report[:32] // Use first 32 bytes as measurement - claims.UEID = report[:16] // Use first 16 bytes as UEID + for _, q := range quotes { + if q == nil { + continue + } + pcrs := q.GetPcrs() + if pcrs == nil { + continue + } + + // Prefer the SHA-256 bank as the canonical raw quote; fall back to the first available. + if pcrs.GetHash() == tpmpb.HashAlgo_SHA256 || vtpm.Quote == nil { + if raw := q.GetQuote(); len(raw) > 0 { + vtpm.Quote = raw + } + } + + hashName := tpmHashName(pcrs.GetHash()) + for idx, val := range pcrs.GetPcrs() { + vtpm.PCRs[fmt.Sprintf("%s:%d", hashName, idx)] = hex.EncodeToString(val) + } + } + + claims.VTPMExtensions = vtpm +} + +func tpmHashName(h tpmpb.HashAlgo) string { + switch h { + case tpmpb.HashAlgo_SHA1: + return "sha1" + case tpmpb.HashAlgo_SHA256: + return "sha256" + case tpmpb.HashAlgo_SHA384: + return "sha384" + case tpmpb.HashAlgo_SHA512: + return "sha512" + default: + return fmt.Sprintf("hash%d", int(h)) + } +} + +// extractVTPMClaims extracts vTPM specific claims from a proto-marshaled attest.Attestation. +func extractVTPMClaims(claims *EATClaims, report []byte) error { + var tpmAttest attestpb.Attestation + if err := proto.Unmarshal(report, &tpmAttest); err != nil { + return fmt.Errorf("failed to parse vTPM attestation: %w", err) + } + + populateVTPMClaims(claims, tpmAttest.GetQuotes(), tpmAttest.GetEventLog()) + + // Use PCR0 (SHA-256) as the canonical measurement if present. + if ext := claims.VTPMExtensions; ext != nil { + if v, ok := ext.PCRs["sha256:0"]; ok { + b, err := hex.DecodeString(v) + if err == nil { + claims.Measurements = b + if len(b) >= 16 { + claims.UEID = b[:16] + } + } + } + } return nil } diff --git a/pkg/attestation/eat/extractor_test.go b/pkg/attestation/eat/extractor_test.go index f212f849..a1d938ca 100644 --- a/pkg/attestation/eat/extractor_test.go +++ b/pkg/attestation/eat/extractor_test.go @@ -9,8 +9,11 @@ import ( "testing" "github.com/google/go-sev-guest/abi" + attestpb "github.com/google/go-tpm-tools/proto/attest" + tpmpb "github.com/google/go-tpm-tools/proto/tpm" "github.com/stretchr/testify/assert" "github.com/ultravioletrs/cocos/pkg/attestation" + "google.golang.org/protobuf/proto" ) func TestExtractSNPClaims(t *testing.T) { @@ -99,16 +102,29 @@ func TestTDXExtensionsJSON(t *testing.T) { } func TestExtractVTPMClaims(t *testing.T) { - report := make([]byte, 32) - copy(report, []byte("vtpm-report-with-enough-length-123")) + pcr0 := []byte("0123456789abcdef0123456789abcdef") + rawQuote := []byte("raw-vtpm-quote") + report, err := proto.Marshal(&attestpb.Attestation{ + Quotes: []*tpmpb.Quote{{ + Quote: rawQuote, + Pcrs: &tpmpb.PCRs{ + Hash: tpmpb.HashAlgo_SHA256, + Pcrs: map[uint32][]byte{ + 0: pcr0, + }, + }, + }}, + EventLog: []byte("event-log"), + }) + assert.NoError(t, err) claims := &EATClaims{} - err := extractVTPMClaims(claims, report) + err = extractVTPMClaims(claims, report) assert.NoError(t, err) assert.NotNil(t, claims.VTPMExtensions) - assert.Equal(t, report, claims.VTPMExtensions.Quote) - assert.Equal(t, report[:32], claims.Measurements) - assert.Equal(t, report[:16], claims.UEID) + assert.Equal(t, rawQuote, claims.VTPMExtensions.Quote) + assert.Equal(t, pcr0, claims.Measurements) + assert.Equal(t, pcr0[:16], claims.UEID) } func TestExtractAzureClaims(t *testing.T) {