NOISSUE - Post-handshake aTLS (#582)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* initial post-handshake aTLS implementation

* add header

* rebased

* remove grpc.go and http.go

* fix authenticator issues

* add freshness nonce

---------

Co-authored-by: ultraviolet <cocosai@worker-52.local.pragmatic-it.com>
Co-authored-by: ultraviolet <cocosai@k8s-master.local.pragmatic-it.com>
This commit is contained in:
Danko Miladinovic
2026-03-26 16:57:09 +01:00
committed by GitHub
parent 42b05524c8
commit 80bf813c48
45 changed files with 3716 additions and 2325 deletions
-69
View File
@@ -1,69 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"encoding/asn1"
"encoding/hex"
"fmt"
)
const (
defaultNotAfterYears = 1
nonceLength = 64
nonceSuffix = ".nonce"
)
// Platform-specific OIDs for certificate extensions.
var (
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
)
// CertificateSubject contains certificate subject information.
type CertificateSubject struct {
Organization string
CommonName string
Country string
Province string
Locality string
StreetAddress string
PostalCode string
}
// DefaultCertificateSubject returns the default certificate subject for Ultraviolet.
func DefaultCertificateSubject() CertificateSubject {
return CertificateSubject{
Organization: "Ultraviolet",
CommonName: "Ultraviolet",
Country: "Serbia",
Province: "",
Locality: "Belgrade",
StreetAddress: "Bulevar Arsenija Carnojevica 103",
PostalCode: "11000",
}
}
func extractNonceFromSNI(serverName string) ([]byte, error) {
if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) {
return nil, fmt.Errorf("invalid server name: %s", serverName)
}
nonceStr := serverName[:len(serverName)-len(nonceSuffix)]
nonce, err := hex.DecodeString(nonceStr)
if err != nil {
return nil, fmt.Errorf("failed to decode nonce: %w", err)
}
if len(nonce) != nonceLength {
return nil, fmt.Errorf("invalid nonce length: expected %d bytes, got %d bytes", nonceLength, len(nonce))
}
return nonce, nil
}
func hasNonceSuffix(serverName string) bool {
return len(serverName) >= len(nonceSuffix) &&
serverName[len(serverName)-len(nonceSuffix):] == nonceSuffix
}
File diff suppressed because it is too large Load Diff
-82
View File
@@ -1,82 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"context"
"encoding/asn1"
"fmt"
"github.com/ultravioletrs/cocos/pkg/attestation"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
"golang.org/x/crypto/sha3"
)
// AttestationProvider defines the interface for platform attestation operations.
type AttestationProvider interface {
Attest(pubKey []byte, nonce []byte) ([]byte, error)
OID() asn1.ObjectIdentifier
PlatformType() attestation.PlatformType
}
// PlatformAttestationProvider handles platform attestation operations.
type platformAttestationProvider struct {
attClient attestation_client.Client
oid asn1.ObjectIdentifier
platformType attestation.PlatformType
}
// NewAttestationProvider creates a new attestation provider for the given platform type.
func NewAttestationProvider(attClient attestation_client.Client, platformType attestation.PlatformType) (AttestationProvider, error) {
oid, err := OID(platformType)
if err != nil {
return nil, fmt.Errorf("failed to get OID: %w", err)
}
return &platformAttestationProvider{
attClient: attClient,
oid: oid,
platformType: platformType,
}, nil
}
func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
var reportData [64]byte
copy(reportData[:], hashNonce[:])
var nonceArray [32]byte
copy(nonceArray[:], hashNonce[:32])
// Get signed EAT token from attestation service
// The attestation service maintains a persistent signing key and returns a pre-signed token
eatToken, err := p.attClient.GetAttestation(context.Background(), reportData, nonceArray, p.platformType)
if err != nil {
return nil, fmt.Errorf("failed to get attestation from service: %w", err)
}
return eatToken, nil
}
func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier {
return p.oid
}
func (p *platformAttestationProvider) PlatformType() attestation.PlatformType {
return p.platformType
}
func OID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
switch platformType {
case attestation.SNPvTPM:
return SNPvTPMOID, nil
case attestation.Azure:
return AzureOID, nil
case attestation.TDX:
return TDXOID, nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
-190
View File
@@ -1,190 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
"github.com/absmach/certs"
sdk "github.com/absmach/certs/sdk"
"github.com/ultravioletrs/cocos/pkg/attestation"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
)
// CertificateProvider defines the interface for providing TLS certificates.
type CertificateProvider interface {
GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// AttestedCertificateProvider provides attested TLS certificates.
type attestedCertificateProvider struct {
attestationProvider AttestationProvider
certsSDK sdk.SDK
agentToken string
subject CertificateSubject
useCA bool
cvmID string
ttl time.Duration
notAfterYears int
}
// NewAttestedProvider creates a new attested certificate provider for self-signed certificates.
func NewAttestedProvider(
attestationProvider AttestationProvider,
subject CertificateSubject,
) CertificateProvider {
return &attestedCertificateProvider{
attestationProvider: attestationProvider,
subject: subject,
useCA: false,
notAfterYears: defaultNotAfterYears,
}
}
// NewAttestedCAProvider creates a new attested certificate provider for CA-signed certificates.
func NewAttestedCAProvider(
attestationProvider AttestationProvider,
subject CertificateSubject,
certsSDK sdk.SDK, cvmID, agentToken string,
) CertificateProvider {
return &attestedCertificateProvider{
attestationProvider: attestationProvider,
subject: subject,
certsSDK: certsSDK,
agentToken: agentToken,
useCA: true,
cvmID: cvmID,
ttl: time.Hour * 24 * 365, // Default 1 year
}
}
// SetTTL sets the certificate TTL for CA-signed certificates.
func (p *attestedCertificateProvider) SetTTL(ttl time.Duration) {
p.ttl = ttl
}
func (p *attestedCertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
nonce, err := extractNonceFromSNI(clientHello.ServerName)
if err != nil {
return nil, fmt.Errorf("failed to extract nonce: %w", err)
}
attestationData, err := p.attestationProvider.Attest(pubKeyDER, nonce)
if err != nil {
return nil, fmt.Errorf("failed to get attestation: %w", err)
}
extension := pkix.Extension{
Id: p.attestationProvider.OID(),
Value: attestationData,
}
var certDERBytes []byte
if p.useCA {
certDERBytes, err = p.generateCASignedCertificate(clientHello.Context(), privateKey, extension)
} else {
certDERBytes, err = p.generateSelfSignedCertificate(privateKey, extension)
}
if err != nil {
return nil, fmt.Errorf("failed to generate certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDERBytes},
PrivateKey: privateKey,
}, nil
}
func (p *attestedCertificateProvider) generateSelfSignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().Unix()),
Subject: pkix.Name{
Organization: []string{p.subject.Organization},
Country: []string{p.subject.Country},
Province: []string{p.subject.Province},
Locality: []string{p.subject.Locality},
StreetAddress: []string{p.subject.StreetAddress},
PostalCode: []string{p.subject.PostalCode},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(p.notAfterYears, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: []pkix.Extension{extension},
}
return x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
}
func (p *attestedCertificateProvider) generateCASignedCertificate(ctx context.Context, privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
csrMetadata := certs.CSRMetadata{
Organization: []string{p.subject.Organization},
Country: []string{p.subject.Country},
CommonName: p.subject.CommonName,
Province: []string{p.subject.Province},
Locality: []string{p.subject.Locality},
StreetAddress: []string{p.subject.StreetAddress},
PostalCode: []string{p.subject.PostalCode},
ExtraExtensions: []pkix.Extension{extension},
}
csr, sdkerr := p.certsSDK.CreateCSR(ctx, csrMetadata, privateKey)
if sdkerr != nil {
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
}
cert, err := p.certsSDK.IssueFromCSRInternal(ctx, p.cvmID, p.ttl.String(), string(csr.CSR), p.agentToken)
if err != nil {
return nil, err
}
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
}
if block == nil {
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
}
return block.Bytes, nil
}
func NewProvider(attClient attestation_client.Client, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) {
attestationProvider, err := NewAttestationProvider(attClient, platformType)
if err != nil {
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
}
subject := DefaultCertificateSubject()
if certsSDK != nil {
return NewAttestedCAProvider(attestationProvider, subject, certsSDK, cvmID, agentToken), nil
}
return NewAttestedProvider(attestationProvider, subject), nil
}
-210
View File
@@ -1,210 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/x509"
"encoding/asn1"
"fmt"
"log/slog"
"os"
"time"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/veraison/corim/corim"
"golang.org/x/crypto/sha3"
)
type CertificateVerifier interface {
VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte) error
}
// CertificateVerifier handles certificate verification operations.
type certificateVerifier struct {
rootCAs *x509.CertPool
verifierProvider func(attestation.PlatformType) (attestation.Verifier, error)
}
func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
return &certificateVerifier{
rootCAs: rootCAs,
verifierProvider: platformVerifier,
}
}
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
slog.Debug("Starting peer certificate verification for aTLS")
if len(rawCerts) == 0 {
err := fmt.Errorf("no certificates provided")
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
err = fmt.Errorf("failed to parse x509 certificate: %w", err)
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String())
if err := v.verifyCertificateSignature(cert); err != nil {
err = fmt.Errorf("certificate signature verification failed: %w", err)
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully verified peer certificate signature")
err = v.verifyAttestationExtension(cert, nonce)
if err != nil {
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully verified aTLS attestation extension")
return nil
}
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
rootCAs := v.rootCAs
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
}
_, err := cert.Verify(opts)
return err
}
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
for _, ext := range cert.Extensions {
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType)
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key: %w", err)
}
return v.verifyCertificateExtension(ext.Value, pubKeyDER, nonce, platformType)
}
}
return fmt.Errorf("attestation extension not found in certificate")
}
func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error {
// Decode EAT token from certificate extension
// Note: We don't have the public key for verification here, so we decode without verification
// The signature was created by the attester, and we trust the TEE hardware verification
claims, err := eat.DecodeCBOR(extension, nil)
if err != nil {
return fmt.Errorf("failed to decode EAT token: %w", err)
}
// Verify nonce matches
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
// The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim
// This matches the Attestation Service API and standard cryptographic nonce sizes.
expectedNonce := hashNonce[:32]
// Compare nonces (EAT nonce should match our computed nonce)
if len(claims.Nonce) != len(expectedNonce) {
err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce))
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
nonceMatch := true
for i := range claims.Nonce {
if claims.Nonce[i] != expectedNonce[i] {
nonceMatch = false
break
}
}
if !nonceMatch {
err := fmt.Errorf("nonce mismatch in EAT token")
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
// Get platform verifier
verifier, err := v.verifierProvider(platformType)
if err != nil {
return fmt.Errorf("failed to get platform verifier: %w", err)
}
// Load and parse CoRIM
if attestation.AttestationPolicyPath == "" {
return fmt.Errorf("attestation policy path is not set")
}
corimBytes, err := os.ReadFile(attestation.AttestationPolicyPath)
if err != nil {
return fmt.Errorf("failed to read CoRIM file: %w", err)
}
// Try extracting from COSE Sign1 first
var unsignedCorim *corim.UnsignedCorim
var sc corim.SignedCorim
if err := sc.FromCOSE(corimBytes); err == nil {
// It's a COSE Sign1 message
unsignedCorim = &sc.UnsignedCorim
} else {
// Try parsing as unsigned CoRIM directly
var uc corim.UnsignedCorim
if err := uc.FromCBOR(corimBytes); err != nil {
return fmt.Errorf("failed to parse CoRIM (tried both signed and unsigned): %w", err)
}
unsignedCorim = &uc
}
// Re-wrap in Corim struct expected by Verifiers
// Since verifiers expect the struct from the removed internal package,
// we need to update verifiers to accept veraison/corim types
// For now, we pass the unsignedCorim directly
if err = verifier.VerifyWithCoRIM(claims.RawReport, unsignedCorim); err != nil {
return fmt.Errorf("failed to verify attestation with CoRIM: %w", err)
}
slog.Debug("CoRIM verification passed for aTLS peer certificate")
return nil
}
func platformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
switch {
case oid.Equal(SNPvTPMOID):
return attestation.SNPvTPM, nil
case oid.Equal(AzureOID):
return attestation.Azure, nil
case oid.Equal(TDXOID):
return attestation.TDX, nil
default:
return 0, fmt.Errorf("unsupported OID: %v", oid)
}
}
func platformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
var verifier attestation.Verifier
switch platformType {
case attestation.SNPvTPM:
verifier = vtpm.NewVerifier(nil)
case attestation.Azure:
verifier = azure.NewVerifier(nil)
case attestation.TDX:
verifier = tdx.NewVerifier()
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
return verifier, nil
}
-338
View File
@@ -1,338 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/veraison/corim/corim"
"golang.org/x/crypto/sha3"
)
type mockVerifier struct {
verifyWithCoRIMFunc func(report []byte, manifest *corim.UnsignedCorim) error
}
func (m *mockVerifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
if m.verifyWithCoRIMFunc != nil {
return m.verifyWithCoRIMFunc(report, manifest)
}
return nil
}
func TestVerifyPeerCertificate_Success(t *testing.T) {
// Setup keys and cert templates
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
// Create verifier with mock platform verifier
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
return &mockVerifier{
verifyWithCoRIMFunc: func(report []byte, manifest *corim.UnsignedCorim) error {
return nil
},
}, nil
}
peerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
// Prepare EAT Claims
nonce := []byte("test-nonce")
peerPubKeyDER, err := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
require.NoError(t, err)
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{
Nonce: hashNonce[:32],
RawReport: []byte("mock-report"),
}
eatBytes, err := cbor.Marshal(claims)
require.NoError(t, err)
// Create Peer Cert with EAT extension
peerTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "Test Peer"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{
{
Id: SNPvTPMOID, // Use SNPvTPMOID as default testing OID
Value: eatBytes,
},
},
}
peerCertDER, err := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
require.NoError(t, err)
// Create dummy CoRIM file
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
c := corim.NewUnsignedCorim()
c.SetID("cocos-test-id")
corimBytes, err := c.ToCBOR()
require.NoError(t, err)
policyPath := filepath.Join(tempDir, "attestation_policy.json")
err = os.WriteFile(policyPath, corimBytes, 0o644)
require.NoError(t, err)
oldPolicyPath := attestation.AttestationPolicyPath
attestation.AttestationPolicyPath = policyPath
t.Cleanup(func() {
attestation.AttestationPolicyPath = oldPolicyPath
})
err = verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
assert.NoError(t, err)
}
func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) {
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
caCert, _ := x509.ParseCertificate(caCertDER)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
return &mockVerifier{}, nil
}
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
nonce := []byte("test-nonce")
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{CommonName: "Azure Peer"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
ExtraExtensions: []pkix.Extension{{Id: AzureOID, Value: eatBytes}},
}
peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
tempDir := t.TempDir()
c := corim.NewUnsignedCorim()
c.SetID("cocos-test-id")
corimBytes, _ := c.ToCBOR()
policyPath := filepath.Join(tempDir, "policy.cbor")
_ = os.WriteFile(policyPath, corimBytes, 0o644)
oldPolicyPath := attestation.AttestationPolicyPath
attestation.AttestationPolicyPath = policyPath
t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath })
err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
assert.NoError(t, err)
}
func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) {
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
caCert, _ := x509.ParseCertificate(caCertDER)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
return &mockVerifier{}, nil
}
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
nonce := []byte("test-nonce")
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate := &x509.Certificate{
SerialNumber: big.NewInt(3),
Subject: pkix.Name{CommonName: "TDX Peer"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
ExtraExtensions: []pkix.Extension{{Id: TDXOID, Value: eatBytes}},
}
peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
tempDir := t.TempDir()
c := corim.NewUnsignedCorim()
c.SetID("cocos-test-id")
corimBytes, _ := c.ToCBOR()
policyPath := filepath.Join(tempDir, "policy.cbor")
_ = os.WriteFile(policyPath, corimBytes, 0o644)
oldPolicyPath := attestation.AttestationPolicyPath
attestation.AttestationPolicyPath = policyPath
t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath })
err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
assert.NoError(t, err)
}
func TestVerifyPeerCertificate_Failures_More(t *testing.T) {
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
caCert, _ := x509.ParseCertificate(caCertDER)
rootCAs := x509.NewCertPool()
rootCAs.AddCert(caCert)
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
// Case 1: Invalid OID
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
peerTemplate := &x509.Certificate{
SerialNumber: big.NewInt(4),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{{Id: []int{1, 2, 3}, Value: []byte("val")}},
}
certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
err := verifier.VerifyPeerCertificate([][]byte{certDER}, nil, []byte("nonce"))
assert.ErrorContains(t, err, "attestation extension not found")
// Case 2: Policy path not set
attestation.AttestationPolicyPath = ""
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
nonce := []byte("nonce")
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
err = verifier.VerifyPeerCertificate([][]byte{certDERWithExt}, nil, nonce)
assert.ErrorContains(t, err, "attestation policy path is not set")
}
func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) {
rootCAs := x509.NewCertPool()
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
// Case 1: No certificates
err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce"))
assert.ErrorContains(t, err, "no certificates provided")
// Case 2: Nonce length mismatch
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
nonce := []byte("nonce")
claims := eat.EATClaims{Nonce: []byte("short"), RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "CA"},
IsCA: true,
BasicConstraintsValid: true,
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
}
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
caCert, _ := x509.ParseCertificate(caCertDER)
rootCAs.AddCert(caCert)
peerTemplate := &x509.Certificate{
SerialNumber: big.NewInt(5),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtraExtensions: []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}},
}
certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
assert.ErrorContains(t, err, "nonce length mismatch")
// Case 3: Nonce mismatch
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...)
wrongHashNonce := sha3.Sum512(wrongTeeNonce)
claims.Nonce = wrongHashNonce[:32]
eatBytes, _ = cbor.Marshal(claims)
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
assert.ErrorContains(t, err, "nonce mismatch in EAT token")
// Case 4: Invalid EAT (CBOR decoder failure)
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: []byte("invalid-cbor")}}
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
assert.ErrorContains(t, err, "failed to decode EAT token")
}
+370
View File
@@ -0,0 +1,370 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
)
var (
ErrTruncated = errors.New("ea: truncated input")
ErrInvalidLength = errors.New("ea: invalid length")
ErrUnsupportedHandshakeType = errors.New("ea: unsupported handshake type")
ErrNotTLS13 = errors.New("ea: not TLS 1.3")
ErrUnknownCipherSuite = errors.New("ea: unknown cipher suite")
ErrContextReuse = errors.New("ea: certificate_request_context already used")
ErrInvalidRole = errors.New("ea: invalid authenticator role")
ErrUnsupportedSignatureScheme = errors.New("ea: unsupported signature scheme")
ErrSignatureMismatch = errors.New("ea: CertificateVerify signature mismatch")
ErrFinishedMismatch = errors.New("ea: Finished MAC mismatch")
ErrContextMismatch = errors.New("ea: certificate_request_context mismatch")
ErrBadRequest = errors.New("ea: bad authenticator request")
)
type ValidationResult struct {
Context []byte
Chain []*x509.Certificate
CMWAttestation []byte
Attestation *eaattestation.VerifiedPayload
Empty bool
}
func CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
return createAuthenticator(nil, st, role, req, nil, identity, leafEntryExtensions)
}
func CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
return createAuthenticator(nil, st, role, req, policy, identity, leafEntryExtensions)
}
func createAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
if st.Version != tls.VersionTLS13 {
return nil, ErrNotTLS13
}
if req == nil && role != RoleServer {
return nil, ErrInvalidRole
}
if err := validateCertificateEntryExtensions(leafEntryExtensions, req, policy); err != nil {
return nil, err
}
emptyAuthenticator := len(identity.Certificate) == 0 && identity.PrivateKey == nil
if !emptyAuthenticator && (len(identity.Certificate) == 0 || identity.PrivateKey == nil) {
return nil, ErrBadRequest
}
var reqBytes []byte
var offered []uint16
var ctx []byte
if req != nil {
var err error
reqBytes, err = req.Marshal()
if err != nil {
return nil, err
}
ctx = append([]byte(nil), req.Context...)
if schemes, ok := req.SignatureSchemes(); ok {
offered = schemes
} else {
return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest)
}
} else {
c, err := NewRandomContext(32)
if err != nil {
return nil, err
}
ctx = c
}
hsCtx, h, err := ExportHandshakeContext(st, role)
if err != nil {
return nil, err
}
fk, _, err := ExportFinishedKey(st, role)
if err != nil {
return nil, err
}
if emptyAuthenticator {
if req == nil {
return nil, ErrBadRequest
}
certBytes, err := (CertificateMessage{Context: ctx}).Marshal()
if err != nil {
return nil, err
}
th := hashConcat(h, hsCtx, reqBytes, certBytes)
verifyData := hmacSum(h, fk, th)
finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal()
if err != nil {
return nil, err
}
if err := session.MarkContextUsed(ctx); err != nil {
return nil, err
}
return finBytes, nil
}
scheme, err := chooseSignatureScheme(identity.PrivateKey, offered)
if err != nil {
return nil, err
}
if req == nil && !policyPermitsSignatureScheme(policy, scheme) {
return nil, ErrUnsupportedSignatureScheme
}
entries := make([]CertificateEntry, 0, len(identity.Certificate))
for i, der := range identity.Certificate {
exts := []Extension(nil)
if i == 0 && len(leafEntryExtensions) > 0 {
exts = leafEntryExtensions
}
entries = append(entries, CertificateEntry{CertDER: der, Extensions: exts})
}
certBytes, err := (CertificateMessage{Context: ctx, Entries: entries}).Marshal()
if err != nil {
return nil, err
}
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
sig, err := signCertVerify(identity.PrivateKey, scheme, th1)
if err != nil {
return nil, err
}
cvBytes, err := (CertificateVerifyMessage{Algorithm: scheme, Signature: sig}).Marshal()
if err != nil {
return nil, err
}
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
verifyData := hmacSum(h, fk, th2)
finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal()
if err != nil {
return nil, err
}
if err := session.MarkContextUsed(ctx); err != nil {
return nil, err
}
return append(append(certBytes, cvBytes...), finBytes...), nil
}
func ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
return validateAuthenticator(nil, st, role, req, nil, nil, authBytes, verifyOpts)
}
func ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
return validateAuthenticator(nil, st, role, req, policy, nil, authBytes, verifyOpts)
}
func ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
return validateAuthenticator(nil, st, role, req, nil, &attPolicy, authBytes, verifyOpts)
}
func ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
return validateAuthenticator(nil, st, role, req, policy, &attPolicy, authBytes, verifyOpts)
}
func validateAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, attPolicy *eaattestation.VerificationPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
if st.Version != tls.VersionTLS13 {
return nil, ErrNotTLS13
}
hsCtx, h, err := ExportHandshakeContext(st, role)
if err != nil {
return nil, err
}
fk, _, err := ExportFinishedKey(st, role)
if err != nil {
return nil, err
}
var reqBytes []byte
var offered []uint16
var reqCtx []byte
if req != nil {
reqCtx = req.Context
reqBytes, err = req.Marshal()
if err != nil {
return nil, err
}
if schemes, ok := req.SignatureSchemes(); ok {
offered = schemes
} else {
return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest)
}
}
firstHm, rest, err := UnmarshalHandshakeMessage(authBytes)
if err != nil {
return nil, err
}
if firstHm.Type == HandshakeTypeFinished {
if req == nil || len(rest) != 0 {
return nil, ErrUnsupportedHandshakeType
}
finBytes, _ := MarshalHandshakeMessage(firstHm)
finMsg, _, err := UnmarshalFinishedMessage(finBytes)
if err != nil {
return nil, err
}
certBytes, err := (CertificateMessage{Context: reqCtx}).Marshal()
if err != nil {
return nil, err
}
th := hashConcat(h, hsCtx, reqBytes, certBytes)
expectedFin := hmacSum(h, fk, th)
if !constantTimeEqual(expectedFin, finMsg.VerifyData) {
return nil, ErrFinishedMismatch
}
if err := session.MarkContextUsed(reqCtx); err != nil {
return nil, err
}
return &ValidationResult{
Context: append([]byte(nil), reqCtx...),
Empty: true,
}, nil
}
if firstHm.Type != HandshakeTypeCertificate {
return nil, ErrUnsupportedHandshakeType
}
certHm := firstHm
certBytes, _ := MarshalHandshakeMessage(certHm)
cvHm, rest, err := UnmarshalHandshakeMessage(rest)
if err != nil || cvHm.Type != HandshakeTypeCertificateVerify {
return nil, ErrUnsupportedHandshakeType
}
cvBytes, _ := MarshalHandshakeMessage(cvHm)
finHm, rest, err := UnmarshalHandshakeMessage(rest)
if err != nil || finHm.Type != HandshakeTypeFinished || len(rest) != 0 {
return nil, ErrInvalidLength
}
finBytes, _ := MarshalHandshakeMessage(finHm)
certMsg, _, err := UnmarshalCertificateMessage(certBytes)
if err != nil {
return nil, err
}
cvMsg, _, err := UnmarshalCertificateVerifyMessage(cvBytes)
if err != nil {
return nil, err
}
finMsg, _, err := UnmarshalFinishedMessage(finBytes)
if err != nil {
return nil, err
}
if req != nil && !bytes.Equal(certMsg.Context, reqCtx) {
return nil, ErrContextMismatch
}
if len(certMsg.Entries) == 0 {
// Empty authenticators are encoded as Finished-only. A Certificate
// message with zero entries followed by CertificateVerify/Finished is
// malformed and must not be accepted as an empty authenticator.
return nil, ErrUnsupportedHandshakeType
}
if err := ValidateCMWAttestationPlacement(certMsg.Entries); err != nil {
return nil, err
}
for _, entry := range certMsg.Entries {
if err := validateCertificateEntryExtensions(entry.Extensions, req, policy); err != nil {
return nil, err
}
}
extracted, present, err := ExtractCMWAttestationFromExtensions(certMsg.Entries[0].Extensions)
if err != nil {
return nil, err
}
if present && req != nil && !RequestPermitsCertificateExtension(req, CMWAttestationExtensionType) {
return nil, ErrBadRequest
}
if present && req == nil && !PolicyPermitsCertificateExtension(policy, CMWAttestationExtensionType) {
return nil, ErrBadRequest
}
chain := make([]*x509.Certificate, 0, len(certMsg.Entries))
for _, e := range certMsg.Entries {
c, err := x509.ParseCertificate(e.CertDER)
if err != nil {
return nil, err
}
chain = append(chain, c)
}
leaf := chain[0]
if req != nil {
ok := false
for _, s := range offered {
if s == cvMsg.Algorithm {
ok = true
break
}
}
if !ok {
return nil, ErrUnsupportedSignatureScheme
}
}
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
if err := verifyCertVerify(leaf.PublicKey, cvMsg.Algorithm, th1, cvMsg.Signature); err != nil {
return nil, err
}
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
expectedFin := hmacSum(h, fk, th2)
if !constantTimeEqual(expectedFin, finMsg.VerifyData) {
return nil, ErrFinishedMismatch
}
if verifyOpts != nil {
opts := *verifyOpts
if opts.Intermediates == nil {
opts.Intermediates = x509.NewCertPool()
}
for _, ic := range chain[1:] {
opts.Intermediates.AddCert(ic)
}
if _, err := leaf.Verify(opts); err != nil {
return nil, err
}
}
if err := session.MarkContextUsed(certMsg.Context); err != nil {
return nil, err
}
res := &ValidationResult{
Context: append([]byte(nil), certMsg.Context...),
Chain: chain,
}
if present {
res.CMWAttestation = extracted
parsed, err := eaattestation.ParsePayload(extracted)
if err != nil {
return nil, err
}
var verifierPolicy eaattestation.VerificationPolicy
// A nil attestation policy is intentional: VerifyPayload then fails closed
// for any payload that carries evidence or attestation results without
// explicit verifiers being configured.
if attPolicy != nil {
verifierPolicy = *attPolicy
}
verified, err := eaattestation.VerifyPayload(st, eaattestation.ExporterLabelAttestation, certMsg.Context, leaf, parsed, verifierPolicy)
if err != nil {
return nil, err
}
res.Attestation = verified
}
return res, nil
}
+537
View File
@@ -0,0 +1,537 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"math/big"
"net"
"testing"
"time"
attestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
)
func selfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "ea-phase3"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
t.Fatal(err)
}
return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}
}
func tlsPair(t *testing.T, cert tls.Certificate) (srv, cli *tls.Conn) {
t.Helper()
srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
a, b := net.Pipe()
srv = tls.Server(a, srvConf)
cli = tls.Client(b, cliConf)
errCh := make(chan error, 2)
go func() { errCh <- srv.Handshake() }()
go func() { errCh <- cli.Handshake() }()
for i := 0; i < 2; i++ {
if err := <-errCh; err != nil {
t.Fatalf("handshake: %v", err)
}
}
return srv, cli
}
type acceptEvidenceVerifier struct{}
func (acceptEvidenceVerifier) VerifyEvidence(evidence []byte) error { return nil }
func TestDummyAttestationRoundTrip(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(16)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
CMWAttestationOfferExtension(),
},
}
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
srvState := srv.ConnectionState()
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
Version: 1,
Evidence: []byte("dummy-attestation-report"),
MediaType: "application/eat+cwt",
Binder: attestation.AttestationBinder{
ExporterLabel: attestation.ExporterLabelAttestation,
AIKPubHash: aikPubHash,
Binding: binding,
},
})
if err != nil {
t.Fatal(err)
}
ext, err := CMWAttestationDataExtension(payloadBytes)
if err != nil {
t.Fatal(err)
}
auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext})
if err != nil {
t.Fatal(err)
}
cliState := cli.ConnectionState()
roots := x509.NewCertPool()
roots.AddCert(leaf)
res, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}, attestation.VerificationPolicy{
EvidenceVerifier: acceptEvidenceVerifier{},
})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(res.CMWAttestation, payloadBytes) {
t.Fatalf("cmw mismatch")
}
if res.Attestation == nil || !res.Attestation.BindingVerified || !res.Attestation.EvidenceVerified {
t.Fatalf("expected verified attestation result")
}
}
func TestRejectIfNotOffered(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(8)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
},
}
ext, _ := CMWAttestationDataExtension([]byte("dummy"))
srvState := srv.ConnectionState()
if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil {
t.Fatalf("expected error when cmw_attestation not offered")
}
}
func TestAttestationFailsClosedWithoutVerifier(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(16)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
CMWAttestationOfferExtension(),
},
}
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
srvState := srv.ConnectionState()
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
Version: 1,
Evidence: []byte("dummy-attestation-report"),
Binder: attestation.AttestationBinder{
ExporterLabel: attestation.ExporterLabelAttestation,
AIKPubHash: aikPubHash,
Binding: binding,
},
})
if err != nil {
t.Fatal(err)
}
ext, err := CMWAttestationDataExtension(payloadBytes)
if err != nil {
t.Fatal(err)
}
auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext})
if err != nil {
t.Fatal(err)
}
cliState := cli.ConnectionState()
if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != attestation.ErrEvidenceVerificationMissing {
t.Fatalf("got %v, want %v", err, attestation.ErrEvidenceVerificationMissing)
}
}
func TestRejectCMWAttestationOnIntermediateEntry(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(16)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
CMWAttestationOfferExtension(),
},
}
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
srvState := srv.ConnectionState()
hsCtx, h, err := ExportHandshakeContext(&srvState, RoleServer)
if err != nil {
t.Fatal(err)
}
fk, _, err := ExportFinishedKey(&srvState, RoleServer)
if err != nil {
t.Fatal(err)
}
reqBytes, err := req.Marshal()
if err != nil {
t.Fatal(err)
}
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
Version: 1,
Evidence: []byte("dummy-attestation-report"),
MediaType: "application/eat+cwt",
Binder: attestation.AttestationBinder{
ExporterLabel: attestation.ExporterLabelAttestation,
AIKPubHash: aikPubHash,
Binding: binding,
},
})
if err != nil {
t.Fatal(err)
}
ext, err := CMWAttestationDataExtension(payloadBytes)
if err != nil {
t.Fatal(err)
}
certBytes, err := (CertificateMessage{
Context: ctx,
Entries: []CertificateEntry{
{CertDER: cert.Certificate[0]},
{CertDER: cert.Certificate[0], Extensions: []Extension{ext}},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
sig, err := signCertVerify(cert.PrivateKey, uint16(tls.ECDSAWithP256AndSHA256), th1)
if err != nil {
t.Fatal(err)
}
cvBytes, err := (CertificateVerifyMessage{
Algorithm: uint16(tls.ECDSAWithP256AndSHA256),
Signature: sig,
}).Marshal()
if err != nil {
t.Fatal(err)
}
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
finBytes, err := (FinishedMessage{VerifyData: hmacSum(h, fk, th2)}).Marshal()
if err != nil {
t.Fatal(err)
}
auth := append(append(certBytes, cvBytes...), finBytes...)
cliState := cli.ConnectionState()
if _, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, nil, attestation.VerificationPolicy{
EvidenceVerifier: acceptEvidenceVerifier{},
}); err != ErrBadRequest {
t.Fatalf("got %v, want %v", err, ErrBadRequest)
}
}
func TestSessionRejectsContextReuse(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(12)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
},
}
createSession := NewSession()
srvState := srv.ConnectionState()
auth, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil)
if err != nil {
t.Fatal(err)
}
if _, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil); err != ErrContextReuse {
t.Fatalf("got %v, want %v", err, ErrContextReuse)
}
validateSession := NewSession()
cliState := cli.ConnectionState()
roots := x509.NewCertPool()
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
roots.AddCert(leaf)
if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != nil {
t.Fatal(err)
}
if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != ErrContextReuse {
t.Fatalf("got %v, want %v", err, ErrContextReuse)
}
}
func TestEmptyAuthenticatorRoundTrip(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(10)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
},
}
srvState := srv.ConnectionState()
auth, err := CreateAuthenticator(&srvState, RoleServer, req, tls.Certificate{}, nil)
if err != nil {
t.Fatal(err)
}
cliState := cli.ConnectionState()
res, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil)
if err != nil {
t.Fatal(err)
}
if !res.Empty {
t.Fatalf("expected empty authenticator result")
}
if len(res.Chain) != 0 {
t.Fatalf("expected no certificate chain")
}
}
func TestRejectCertificateMessageWithEmptyEntries(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
ctx, _ := NewRandomContext(10)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
},
}
certBytes, err := (CertificateMessage{Context: ctx}).Marshal()
if err != nil {
t.Fatal(err)
}
cvBytes, err := (CertificateVerifyMessage{
Algorithm: uint16(tls.ECDSAWithP256AndSHA256),
Signature: []byte{0x01},
}).Marshal()
if err != nil {
t.Fatal(err)
}
finBytes, err := (FinishedMessage{VerifyData: []byte{0x01}}).Marshal()
if err != nil {
t.Fatal(err)
}
auth := append(append(certBytes, cvBytes...), finBytes...)
cliState := cli.ConnectionState()
if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != ErrUnsupportedHandshakeType {
t.Fatalf("got %v, want %v", err, ErrUnsupportedHandshakeType)
}
}
func TestRejectSpontaneousClientAuthenticator(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
cliState := cli.ConnectionState()
if _, err := CreateAuthenticator(&cliState, RoleClient, nil, cert, nil); err != ErrInvalidRole {
t.Fatalf("got %v, want %v", err, ErrInvalidRole)
}
}
func TestRequestParsers(t *testing.T) {
oidDER, err := asn1.Marshal(asn1.ObjectIdentifier{2, 5, 4, 3})
if err != nil {
t.Fatal(err)
}
oidFilterPayload := append([]byte{byte(len(oidDER))}, oidDER...)
oidFilterPayload = append(oidFilterPayload, 0x00, 0x02, 'o', 'k')
req := AuthenticatorRequest{
Type: HandshakeTypeCertificateRequest,
Context: []byte{1, 2, 3},
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x04, 0x04, 0x03, 0x08, 0x07}},
{Type: SignatureAlgorithmsCertExtensionType, Data: []byte{0x00, 0x02, 0x08, 0x07}},
{Type: CertificateAuthoritiesExtensionType, Data: []byte{0x00, 0x06, 0x00, 0x04, 't', 'e', 's', 't'}},
{Type: OIDFiltersExtensionType, Data: append([]byte{0x00, byte(len(oidFilterPayload))}, oidFilterPayload...)},
},
}
if got, ok := req.SignatureSchemes(); !ok || len(got) != 2 || got[0] != uint16(tls.ECDSAWithP256AndSHA256) || got[1] != uint16(tls.Ed25519) {
t.Fatalf("unexpected signature schemes: %v %v", got, ok)
}
if got, ok := req.SignatureSchemesCert(); !ok || len(got) != 1 || got[0] != uint16(tls.Ed25519) {
t.Fatalf("unexpected signature_algorithms_cert: %v %v", got, ok)
}
if got, ok := req.CertificateAuthorities(); !ok || len(got) != 1 || string(got[0]) != "test" {
t.Fatalf("unexpected certificate authorities: %q %v", got, ok)
}
if got, ok := req.OIDFilters(); !ok || len(got) != 1 || !got[0].OID.Equal(asn1.ObjectIdentifier{2, 5, 4, 3}) || string(got[0].Values) != "ok" {
t.Fatalf("unexpected oid filters: %#v %v", got, ok)
}
}
func TestRejectLeafExtensionNotPermittedByRequest(t *testing.T) {
cert := selfSignedCert(t)
srv, _ := tlsPair(t, cert)
defer srv.Close()
ctx, _ := NewRandomContext(8)
req := &AuthenticatorRequest{
Type: HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []Extension{
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
},
}
srvState := srv.ConnectionState()
ext := Extension{Type: 0x1234, Data: []byte{0x00}}
if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil {
t.Fatalf("expected policy error for unpermitted leaf extension")
}
}
func TestSpontaneousPolicyPermitsCertificateExtension(t *testing.T) {
cert := selfSignedCert(t)
srv, cli := tlsPair(t, cert)
defer srv.Close()
defer cli.Close()
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
Version: 1,
Evidence: []byte("dummy-attestation-report"),
MediaType: "application/eat+cwt",
Binder: attestation.AttestationBinder{
ExporterLabel: attestation.ExporterLabelAttestation,
AIKPubHash: []byte("placeholder-aik"),
Binding: []byte("placeholder-binding"),
},
})
if err != nil {
t.Fatal(err)
}
ext, err := CMWAttestationDataExtension(payloadBytes)
if err != nil {
t.Fatal(err)
}
policy := &SpontaneousAuthenticatorPolicy{
AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)},
AllowedCertificateExtensions: []uint16{CMWAttestationExtensionType},
}
srvState := srv.ConnectionState()
auth, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext})
if err != nil {
t.Fatal(err)
}
if len(auth) == 0 {
t.Fatalf("expected authenticator bytes")
}
}
func TestSpontaneousPolicyRejectsCertificateExtension(t *testing.T) {
cert := selfSignedCert(t)
srv, _ := tlsPair(t, cert)
defer srv.Close()
ext, err := CMWAttestationDataExtension([]byte("dummy-attestation-report"))
if err != nil {
t.Fatal(err)
}
policy := &SpontaneousAuthenticatorPolicy{
AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)},
}
srvState := srv.ConnectionState()
if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext}); err == nil {
t.Fatalf("expected policy error for unpermitted spontaneous extension")
}
}
func TestSpontaneousPolicyRejectsSignatureScheme(t *testing.T) {
cert := selfSignedCert(t)
srv, _ := tlsPair(t, cert)
defer srv.Close()
policy := &SpontaneousAuthenticatorPolicy{
AllowedSignatureSchemes: []uint16{uint16(tls.Ed25519)},
}
srvState := srv.ConnectionState()
if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, nil); err != ErrUnsupportedSignatureScheme {
t.Fatalf("got %v, want %v", err, ErrUnsupportedSignatureScheme)
}
}
+95
View File
@@ -0,0 +1,95 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
type CertificateMessage struct {
Context []byte
Entries []CertificateEntry
}
type CertificateEntry struct {
CertDER []byte
Extensions []Extension
}
func (m CertificateMessage) Marshal() ([]byte, error) {
if len(m.Context) > 255 {
return nil, ErrInvalidLength
}
var listPayload []byte
for _, e := range m.Entries {
if len(e.CertDER) == 0 || len(e.CertDER) > 0xFFFFFF {
return nil, ErrInvalidLength
}
extVec, err := MarshalExtensions(e.Extensions)
if err != nil {
return nil, err
}
entry := make([]byte, 3+len(e.CertDER)+len(extVec))
putUint24(entry[0:3], uint32(len(e.CertDER)))
copy(entry[3:], e.CertDER)
copy(entry[3+len(e.CertDER):], extVec)
listPayload = append(listPayload, entry...)
}
body := make([]byte, 1+len(m.Context)+3+len(listPayload))
body[0] = byte(len(m.Context))
copy(body[1:], m.Context)
putUint24(body[1+len(m.Context):1+len(m.Context)+3], uint32(len(listPayload)))
copy(body[1+len(m.Context)+3:], listPayload)
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificate, Body: body})
}
func UnmarshalCertificateMessage(handshakeBytes []byte) (CertificateMessage, []byte, error) {
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
if err != nil {
return CertificateMessage{}, nil, err
}
if len(rest) != 0 || hm.Type != HandshakeTypeCertificate {
return CertificateMessage{}, nil, ErrInvalidLength
}
if len(hm.Body) < 1 {
return CertificateMessage{}, nil, ErrTruncated
}
ctxLen := int(hm.Body[0])
if len(hm.Body) < 1+ctxLen+3 {
return CertificateMessage{}, nil, ErrTruncated
}
ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...)
listLen := int(readUint24(hm.Body[1+ctxLen : 1+ctxLen+3]))
if len(hm.Body) != 1+ctxLen+3+listLen {
return CertificateMessage{}, nil, ErrInvalidLength
}
list := hm.Body[1+ctxLen+3:]
var entries []CertificateEntry
for i := 0; i < len(list); {
if len(list)-i < 3 {
return CertificateMessage{}, nil, ErrTruncated
}
certLen := int(readUint24(list[i : i+3]))
i += 3
if certLen <= 0 || certLen > len(list)-i {
return CertificateMessage{}, nil, ErrInvalidLength
}
certDER := append([]byte(nil), list[i:i+certLen]...)
i += certLen
exts, leftover, err := UnmarshalExtensions(list[i:])
if err != nil {
return CertificateMessage{}, nil, err
}
consumed := len(list[i:]) - len(leftover)
i += consumed
entries = append(entries, CertificateEntry{CertDER: certDER, Extensions: exts})
}
raw, _ := MarshalHandshakeMessage(hm)
return CertificateMessage{Context: ctx, Entries: entries}, raw, nil
}
+148
View File
@@ -0,0 +1,148 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
)
type CertificateVerifyMessage struct {
Algorithm uint16
Signature []byte
}
func (m CertificateVerifyMessage) Marshal() ([]byte, error) {
if len(m.Signature) > 0xFFFF {
return nil, ErrInvalidLength
}
body := make([]byte, 4+len(m.Signature))
putUint16(body[0:2], m.Algorithm)
putUint16(body[2:4], uint16(len(m.Signature)))
copy(body[4:], m.Signature)
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificateVerify, Body: body})
}
func UnmarshalCertificateVerifyMessage(handshakeBytes []byte) (CertificateVerifyMessage, []byte, error) {
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
if err != nil {
return CertificateVerifyMessage{}, nil, err
}
if len(rest) != 0 || hm.Type != HandshakeTypeCertificateVerify {
return CertificateVerifyMessage{}, nil, ErrInvalidLength
}
if len(hm.Body) < 4 {
return CertificateVerifyMessage{}, nil, ErrTruncated
}
alg := readUint16(hm.Body[0:2])
sigLen := int(readUint16(hm.Body[2:4]))
if len(hm.Body) != 4+sigLen {
return CertificateVerifyMessage{}, nil, ErrInvalidLength
}
sig := append([]byte(nil), hm.Body[4:]...)
raw, _ := MarshalHandshakeMessage(hm)
return CertificateVerifyMessage{Algorithm: alg, Signature: sig}, raw, nil
}
var eaContextString = []byte("Exported Authenticator")
func buildCertVerifyInput(transcriptHash []byte) []byte {
prefix := bytes.Repeat([]byte{0x20}, 64)
out := make([]byte, 0, len(prefix)+len(eaContextString)+1+len(transcriptHash))
out = append(out, prefix...)
out = append(out, eaContextString...)
out = append(out, 0x00)
out = append(out, transcriptHash...)
return out
}
func signCertVerify(priv any, scheme uint16, transcriptHash []byte) ([]byte, error) {
info, err := signatureSchemeInfo(scheme)
if err != nil {
return nil, err
}
msg := buildCertVerifyInput(transcriptHash)
switch info.Alg {
case sigAlgECDSA:
k, ok := priv.(*ecdsa.PrivateKey)
if !ok {
return nil, ErrUnsupportedSignatureScheme
}
h := info.Hash.New()
h.Write(msg)
return ecdsa.SignASN1(rand.Reader, k, h.Sum(nil))
case sigAlgRSAPSS:
k, ok := priv.(*rsa.PrivateKey)
if !ok {
return nil, ErrUnsupportedSignatureScheme
}
h := info.Hash.New()
h.Write(msg)
return rsa.SignPSS(rand.Reader, k, info.Hash, h.Sum(nil),
&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash})
case sigAlgEd25519:
k, ok := priv.(ed25519.PrivateKey)
if !ok {
return nil, ErrUnsupportedSignatureScheme
}
return ed25519.Sign(k, msg), nil
default:
return nil, ErrUnsupportedSignatureScheme
}
}
func verifyCertVerify(pub any, scheme uint16, transcriptHash []byte, signature []byte) error {
info, err := signatureSchemeInfo(scheme)
if err != nil {
return err
}
msg := buildCertVerifyInput(transcriptHash)
switch info.Alg {
case sigAlgECDSA:
k, ok := pub.(*ecdsa.PublicKey)
if !ok {
return ErrUnsupportedSignatureScheme
}
h := info.Hash.New()
h.Write(msg)
if !ecdsa.VerifyASN1(k, h.Sum(nil), signature) {
return ErrSignatureMismatch
}
return nil
case sigAlgRSAPSS:
k, ok := pub.(*rsa.PublicKey)
if !ok {
return ErrUnsupportedSignatureScheme
}
h := info.Hash.New()
h.Write(msg)
if err := rsa.VerifyPSS(k, info.Hash, h.Sum(nil), signature,
&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash}); err != nil {
return ErrSignatureMismatch
}
return nil
case sigAlgEd25519:
k, ok := pub.(ed25519.PublicKey)
if !ok {
return ErrUnsupportedSignatureScheme
}
if !ed25519.Verify(k, msg, signature) {
return ErrSignatureMismatch
}
return nil
default:
return ErrUnsupportedSignatureScheme
}
}
+51
View File
@@ -0,0 +1,51 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
const CMWAttestationExtensionType uint16 = 0xFF00
func CMWAttestationOfferExtension() Extension {
return Extension{Type: CMWAttestationExtensionType, Data: nil}
}
func CMWAttestationDataExtension(cmw []byte) (Extension, error) {
if len(cmw) == 0 || len(cmw) > 0xFFFF {
return Extension{}, ErrInvalidLength
}
data := make([]byte, 2+len(cmw))
putUint16(data[0:2], uint16(len(cmw)))
copy(data[2:], cmw)
return Extension{Type: CMWAttestationExtensionType, Data: data}, nil
}
func ExtractCMWAttestationFromExtensions(exts []Extension) ([]byte, bool, error) {
for _, e := range exts {
if e.Type != CMWAttestationExtensionType {
continue
}
if len(e.Data) < 2 {
return nil, true, ErrInvalidLength
}
l := int(readUint16(e.Data[0:2]))
if l <= 0 || l != len(e.Data)-2 {
return nil, true, ErrInvalidLength
}
return append([]byte(nil), e.Data[2:]...), true, nil
}
return nil, false, nil
}
func ValidateCMWAttestationPlacement(entries []CertificateEntry) error {
for i, entry := range entries {
for _, ext := range entry.Extensions {
if ext.Type != CMWAttestationExtensionType {
continue
}
if i != 0 {
return ErrBadRequest
}
}
}
return nil
}
+67
View File
@@ -0,0 +1,67 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"crypto"
"crypto/tls"
"fmt"
)
const (
LabelClientAuthenticatorHandshakeContext = "EXPORTER-client authenticator handshake context"
LabelServerAuthenticatorHandshakeContext = "EXPORTER-server authenticator handshake context"
LabelClientAuthenticatorFinishedKey = "EXPORTER-client authenticator finished key"
LabelServerAuthenticatorFinishedKey = "EXPORTER-server authenticator finished key"
)
type Role uint8
const (
RoleClient Role = iota + 1
RoleServer
)
func AuthenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) {
switch cipherSuite {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256:
return crypto.SHA256, nil
case tls.TLS_AES_256_GCM_SHA384:
return crypto.SHA384, nil
default:
return 0, fmt.Errorf("%w: %s (0x%04x)", ErrUnknownCipherSuite, tls.CipherSuiteName(cipherSuite), cipherSuite)
}
}
func ExportHandshakeContext(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) {
if st.Version != tls.VersionTLS13 {
return nil, 0, ErrNotTLS13
}
h, err := AuthenticatorHashTLS13(st.CipherSuite)
if err != nil {
return nil, 0, err
}
label := LabelClientAuthenticatorHandshakeContext
if role == RoleServer {
label = LabelServerAuthenticatorHandshakeContext
}
out, err := st.ExportKeyingMaterial(label, nil, h.Size())
return out, h, err
}
func ExportFinishedKey(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) {
if st.Version != tls.VersionTLS13 {
return nil, 0, ErrNotTLS13
}
h, err := AuthenticatorHashTLS13(st.CipherSuite)
if err != nil {
return nil, 0, err
}
label := LabelClientAuthenticatorFinishedKey
if role == RoleServer {
label = LabelServerAuthenticatorFinishedKey
}
out, err := st.ExportKeyingMaterial(label, nil, h.Size())
return out, h, err
}
+61
View File
@@ -0,0 +1,61 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
type Extension struct {
Type uint16
Data []byte
}
func MarshalExtensions(exts []Extension) ([]byte, error) {
payloadLen := 0
for _, e := range exts {
if len(e.Data) > 0xFFFF {
return nil, ErrInvalidLength
}
payloadLen += 4 + len(e.Data)
if payloadLen > 0xFFFF {
return nil, ErrInvalidLength
}
}
out := make([]byte, 2+payloadLen)
putUint16(out[0:2], uint16(payloadLen))
off := 2
for _, e := range exts {
putUint16(out[off:off+2], e.Type)
putUint16(out[off+2:off+4], uint16(len(e.Data)))
copy(out[off+4:], e.Data)
off += 4 + len(e.Data)
}
return out, nil
}
func UnmarshalExtensions(b []byte) (exts []Extension, rest []byte, err error) {
if len(b) < 2 {
return nil, nil, ErrTruncated
}
total := int(readUint16(b[0:2]))
if len(b) < 2+total {
return nil, nil, ErrTruncated
}
payload := b[2 : 2+total]
rest = b[2+total:]
i := 0
for i < len(payload) {
if len(payload)-i < 4 {
return nil, nil, ErrTruncated
}
typ := readUint16(payload[i : i+2])
l := int(readUint16(payload[i+2 : i+4]))
i += 4
if l < 0 || l > len(payload)-i {
return nil, nil, ErrInvalidLength
}
data := make([]byte, l)
copy(data, payload[i:i+l])
i += l
exts = append(exts, Extension{Type: typ, Data: data})
}
return exts, rest, nil
}
+27
View File
@@ -0,0 +1,27 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
type FinishedMessage struct {
VerifyData []byte
}
func (m FinishedMessage) Marshal() ([]byte, error) {
if len(m.VerifyData) == 0 {
return nil, ErrInvalidLength
}
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeFinished, Body: append([]byte(nil), m.VerifyData...)})
}
func UnmarshalFinishedMessage(handshakeBytes []byte) (FinishedMessage, []byte, error) {
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
if err != nil {
return FinishedMessage{}, nil, err
}
if len(rest) != 0 || hm.Type != HandshakeTypeFinished || len(hm.Body) == 0 {
return FinishedMessage{}, nil, ErrInvalidLength
}
raw, _ := MarshalHandshakeMessage(hm)
return FinishedMessage{VerifyData: append([]byte(nil), hm.Body...)}, raw, nil
}
+55
View File
@@ -0,0 +1,55 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
const handshakeHeaderLen = 4 // 1 byte type + 3 byte uint24 len
const (
HandshakeTypeCertificate uint8 = 11
HandshakeTypeCertificateRequest uint8 = 13
HandshakeTypeCertificateVerify uint8 = 15
HandshakeTypeClientCertificateRequest uint8 = 17
HandshakeTypeFinished uint8 = 20
)
type HandshakeMessage struct {
Type uint8
Body []byte
}
func MarshalHandshakeMessage(m HandshakeMessage) ([]byte, error) {
if len(m.Body) > 0xFFFFFF {
return nil, ErrInvalidLength
}
out := make([]byte, handshakeHeaderLen+len(m.Body))
out[0] = m.Type
putUint24(out[1:4], uint32(len(m.Body)))
copy(out[4:], m.Body)
return out, nil
}
func UnmarshalHandshakeMessage(b []byte) (msg HandshakeMessage, rest []byte, err error) {
if len(b) < handshakeHeaderLen {
return HandshakeMessage{}, nil, ErrTruncated
}
t := b[0]
n := int(readUint24(b[1:4]))
if n < 0 || n > 0xFFFFFF {
return HandshakeMessage{}, nil, ErrInvalidLength
}
if len(b) < handshakeHeaderLen+n {
return HandshakeMessage{}, nil, ErrTruncated
}
body := make([]byte, n)
copy(body, b[4:4+n])
return HandshakeMessage{Type: t, Body: body}, b[4+n:], nil
}
func putUint24(dst []byte, v uint32) { dst[0] = byte(v >> 16); dst[1] = byte(v >> 8); dst[2] = byte(v) }
func readUint24(src []byte) uint32 {
return (uint32(src[0]) << 16) | (uint32(src[1]) << 8) | uint32(src[2])
}
func putUint16(dst []byte, v uint16) { dst[0] = byte(v >> 8); dst[1] = byte(v) }
func readUint16(src []byte) uint16 { return (uint16(src[0]) << 8) | uint16(src[1]) }
+60
View File
@@ -0,0 +1,60 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
type SpontaneousAuthenticatorPolicy struct {
AllowedSignatureSchemes []uint16
AllowedCertificateExtensions []uint16
}
func RequestPermitsCertificateExtension(req *AuthenticatorRequest, typ uint16) bool {
if req == nil {
return false
}
for _, e := range req.Extensions {
if e.Type == typ {
return true
}
}
return false
}
func PolicyPermitsCertificateExtension(policy *SpontaneousAuthenticatorPolicy, typ uint16) bool {
if policy == nil {
return false
}
for _, allowed := range policy.AllowedCertificateExtensions {
if allowed == typ {
return true
}
}
return false
}
func policyPermitsSignatureScheme(policy *SpontaneousAuthenticatorPolicy, scheme uint16) bool {
if policy == nil || len(policy.AllowedSignatureSchemes) == 0 {
return true
}
for _, allowed := range policy.AllowedSignatureSchemes {
if allowed == scheme {
return true
}
}
return false
}
func validateCertificateEntryExtensions(exts []Extension, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy) error {
for _, e := range exts {
if req != nil {
if !RequestPermitsCertificateExtension(req, e.Type) {
return ErrBadRequest
}
continue
}
if !PolicyPermitsCertificateExtension(policy, e.Type) {
return ErrBadRequest
}
}
return nil
}
+214
View File
@@ -0,0 +1,214 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"crypto/rand"
"encoding/asn1"
)
const (
SignatureAlgorithmsExtensionType uint16 = 0x000d
ServerNameExtensionType uint16 = 0x0000
CertificateAuthoritiesExtensionType uint16 = 0x002f
OIDFiltersExtensionType uint16 = 0x0030
SignatureAlgorithmsCertExtensionType uint16 = 0x0032
)
type AuthenticatorRequest struct {
Type uint8
Context []byte
Extensions []Extension
}
type OIDFilter struct {
OID asn1.ObjectIdentifier
Values []byte
}
func NewRandomContext(n int) ([]byte, error) {
if n <= 0 || n > 255 {
return nil, ErrInvalidLength
}
b := make([]byte, n)
_, err := rand.Read(b)
return b, err
}
func (r AuthenticatorRequest) Marshal() ([]byte, error) {
if r.Type != HandshakeTypeCertificateRequest && r.Type != HandshakeTypeClientCertificateRequest {
return nil, ErrUnsupportedHandshakeType
}
if len(r.Context) > 255 {
return nil, ErrInvalidLength
}
extVec, err := MarshalExtensions(r.Extensions)
if err != nil {
return nil, err
}
body := make([]byte, 1+len(r.Context)+len(extVec))
body[0] = byte(len(r.Context))
copy(body[1:], r.Context)
copy(body[1+len(r.Context):], extVec)
return MarshalHandshakeMessage(HandshakeMessage{Type: r.Type, Body: body})
}
func UnmarshalAuthenticatorRequest(handshakeBytes []byte) (AuthenticatorRequest, []byte, error) {
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
if err != nil {
return AuthenticatorRequest{}, nil, err
}
if hm.Type != HandshakeTypeCertificateRequest && hm.Type != HandshakeTypeClientCertificateRequest {
return AuthenticatorRequest{}, nil, ErrUnsupportedHandshakeType
}
if len(hm.Body) < 1 {
return AuthenticatorRequest{}, nil, ErrTruncated
}
ctxLen := int(hm.Body[0])
if len(hm.Body) < 1+ctxLen {
return AuthenticatorRequest{}, nil, ErrTruncated
}
ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...)
exts, leftover, err := UnmarshalExtensions(hm.Body[1+ctxLen:])
if err != nil {
return AuthenticatorRequest{}, nil, err
}
if len(leftover) != 0 {
return AuthenticatorRequest{}, nil, ErrInvalidLength
}
return AuthenticatorRequest{
Type: hm.Type,
Context: ctx,
Extensions: exts,
}, rest, nil
}
func (r AuthenticatorRequest) SignatureSchemes() ([]uint16, bool) {
return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsExtensionType)
}
func (r AuthenticatorRequest) SignatureSchemesCert() ([]uint16, bool) {
return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsCertExtensionType)
}
func (r AuthenticatorRequest) CertificateAuthorities() ([][]byte, bool) {
for _, e := range r.Extensions {
if e.Type != CertificateAuthoritiesExtensionType {
continue
}
if len(e.Data) < 2 {
return nil, false
}
total := int(readUint16(e.Data[0:2]))
if total < 3 || len(e.Data) != 2+total {
return nil, false
}
var out [][]byte
for off := 2; off < len(e.Data); {
if len(e.Data)-off < 2 {
return nil, false
}
l := int(readUint16(e.Data[off : off+2]))
off += 2
if l == 0 || l > len(e.Data)-off {
return nil, false
}
out = append(out, append([]byte(nil), e.Data[off:off+l]...))
off += l
}
return out, true
}
return nil, false
}
func (r AuthenticatorRequest) OIDFilters() ([]OIDFilter, bool) {
for _, e := range r.Extensions {
if e.Type != OIDFiltersExtensionType {
continue
}
if len(e.Data) < 2 {
return nil, false
}
total := int(readUint16(e.Data[0:2]))
if len(e.Data) != 2+total {
return nil, false
}
var out []OIDFilter
for off := 2; off < len(e.Data); {
if len(e.Data)-off < 1 {
return nil, false
}
oidLen := int(e.Data[off])
off++
if oidLen == 0 || oidLen > len(e.Data)-off {
return nil, false
}
rawOID := append([]byte(nil), e.Data[off:off+oidLen]...)
off += oidLen
var oid asn1.ObjectIdentifier
if _, err := asn1.Unmarshal(rawOID, &oid); err != nil {
return nil, false
}
if len(e.Data)-off < 2 {
return nil, false
}
valLen := int(readUint16(e.Data[off : off+2]))
off += 2
if valLen > len(e.Data)-off {
return nil, false
}
values := append([]byte(nil), e.Data[off:off+valLen]...)
off += valLen
out = append(out, OIDFilter{OID: oid, Values: values})
}
return out, true
}
return nil, false
}
func parseSignatureSchemesExtension(exts []Extension, typ uint16) ([]uint16, bool) {
for _, e := range exts {
if e.Type != typ {
continue
}
if len(e.Data) < 2 {
return nil, false
}
vecLen := int(readUint16(e.Data[0:2]))
if vecLen < 2 || vecLen%2 != 0 || len(e.Data) != 2+vecLen {
return nil, false
}
out := make([]uint16, 0, vecLen/2)
for off := 2; off < len(e.Data); off += 2 {
out = append(out, readUint16(e.Data[off:off+2]))
}
return out, true
}
return nil, false
}
func SignatureAlgorithmsExtension(schemes []uint16) (Extension, error) {
return marshalSignatureSchemesExtension(SignatureAlgorithmsExtensionType, schemes)
}
func SignatureAlgorithmsCertExtension(schemes []uint16) (Extension, error) {
return marshalSignatureSchemesExtension(SignatureAlgorithmsCertExtensionType, schemes)
}
func marshalSignatureSchemesExtension(typ uint16, schemes []uint16) (Extension, error) {
if len(schemes) == 0 {
return Extension{}, ErrInvalidLength
}
if len(schemes) > 0x7fff {
return Extension{}, ErrInvalidLength
}
data := make([]byte, 2+2*len(schemes))
putUint16(data[0:2], uint16(2*len(schemes)))
off := 2
for _, s := range schemes {
putUint16(data[off:off+2], s)
off += 2
}
return Extension{Type: typ, Data: data}, nil
}
+59
View File
@@ -0,0 +1,59 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"crypto/tls"
"crypto/x509"
"sync"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
)
type Session struct {
mu sync.Mutex
used map[string]struct{}
}
func NewSession() *Session {
return &Session{used: make(map[string]struct{})}
}
func (s *Session) MarkContextUsed(ctx []byte) error {
if s == nil {
return nil
}
key := string(ctx)
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.used[key]; ok {
return ErrContextReuse
}
s.used[key] = struct{}{}
return nil
}
func (s *Session) CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
return createAuthenticator(s, st, role, req, nil, identity, leafEntryExtensions)
}
func (s *Session) ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
return validateAuthenticator(s, st, role, req, nil, nil, authBytes, verifyOpts)
}
func (s *Session) CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
return createAuthenticator(s, st, role, req, policy, identity, leafEntryExtensions)
}
func (s *Session) ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
return validateAuthenticator(s, st, role, req, policy, nil, authBytes, verifyOpts)
}
func (s *Session) ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
return validateAuthenticator(s, st, role, req, nil, &attPolicy, authBytes, verifyOpts)
}
func (s *Session) ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
return validateAuthenticator(s, st, role, req, policy, &attPolicy, authBytes, verifyOpts)
}
+82
View File
@@ -0,0 +1,82 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/tls"
"fmt"
)
type sigAlg uint8
const (
sigAlgECDSA sigAlg = iota + 1
sigAlgRSAPSS
sigAlgEd25519
)
type sigSchemeInfo struct {
Scheme uint16
Alg sigAlg
Hash crypto.Hash // 0 for Ed25519
}
func signatureSchemeInfo(s uint16) (sigSchemeInfo, error) {
switch s {
case uint16(tls.ECDSAWithP256AndSHA256):
return sigSchemeInfo{s, sigAlgECDSA, crypto.SHA256}, nil
case uint16(tls.PSSWithSHA256):
return sigSchemeInfo{s, sigAlgRSAPSS, crypto.SHA256}, nil
case uint16(tls.Ed25519):
return sigSchemeInfo{s, sigAlgEd25519, 0}, nil
default:
return sigSchemeInfo{}, fmt.Errorf("%w: 0x%04x", ErrUnsupportedSignatureScheme, s)
}
}
func chooseSignatureScheme(priv any, offered []uint16) (uint16, error) {
compat := func(s uint16) bool {
info, err := signatureSchemeInfo(s)
if err != nil {
return false
}
switch info.Alg {
case sigAlgECDSA:
k, ok := priv.(*ecdsa.PrivateKey)
return ok && k.Curve.Params().Name == "P-256"
case sigAlgRSAPSS:
_, ok := priv.(*rsa.PrivateKey)
return ok
case sigAlgEd25519:
_, ok := priv.(ed25519.PrivateKey)
return ok
default:
return false
}
}
if len(offered) > 0 {
for _, s := range offered {
if compat(s) {
return s, nil
}
}
return 0, ErrUnsupportedSignatureScheme
}
if k, ok := priv.(*ecdsa.PrivateKey); ok && k.Curve.Params().Name == "P-256" {
return uint16(tls.ECDSAWithP256AndSHA256), nil
}
if _, ok := priv.(*rsa.PrivateKey); ok {
return uint16(tls.PSSWithSHA256), nil
}
if _, ok := priv.(ed25519.PrivateKey); ok {
return uint16(tls.Ed25519), nil
}
return 0, ErrUnsupportedSignatureScheme
}
+34
View File
@@ -0,0 +1,34 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ea
import (
"crypto"
"crypto/hmac"
"crypto/subtle"
)
func hashConcat(h crypto.Hash, chunks ...[]byte) []byte {
hs := h.New()
for _, c := range chunks {
if len(c) == 0 {
continue
}
hs.Write(c)
}
return hs.Sum(nil)
}
func hmacSum(h crypto.Hash, key, data []byte) []byte {
m := hmac.New(h.New, key)
m.Write(data)
return m.Sum(nil)
}
func constantTimeEqual(a, b []byte) bool {
if len(a) != len(b) {
return false
}
return subtle.ConstantTimeCompare(a, b) == 1
}
+88
View File
@@ -0,0 +1,88 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
)
const (
ExporterLabelAttestation = "Attestation"
ExporterLabelAttestationBinding = "Attestation Binding"
)
const ExportedAttestationValueLen = 32
var errNotTLS13 = errors.New("attestation: not TLS 1.3")
func ExportAttestationValue(st *tls.ConnectionState, label string, contextValue []byte) ([]byte, crypto.Hash, error) {
if st.Version != tls.VersionTLS13 {
return nil, 0, errNotTLS13
}
h, err := authenticatorHashTLS13(st.CipherSuite)
if err != nil {
return nil, 0, err
}
out, err := st.ExportKeyingMaterial(label, contextValue, ExportedAttestationValueLen)
if err != nil {
return nil, 0, err
}
return out, h, nil
}
func PublicKeyBytes(leaf *x509.Certificate) ([]byte, error) {
if leaf == nil {
return nil, fmt.Errorf("nil leaf cert")
}
if len(leaf.RawSubjectPublicKeyInfo) > 0 {
return leaf.RawSubjectPublicKeyInfo, nil
}
b, err := x509.MarshalPKIXPublicKey(leaf.PublicKey)
if err != nil {
return nil, err
}
return b, nil
}
func AIKPublicKeyHash(h crypto.Hash, pubKey []byte) []byte {
hs := h.New()
hs.Write(pubKey)
return hs.Sum(nil)
}
func BindingValue(h crypto.Hash, pubKey, exportedValue []byte) []byte {
hs := h.New()
hs.Write(pubKey)
hs.Write(exportedValue)
return hs.Sum(nil)
}
func ComputeBinding(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate) (exportedValue, aikPubHash, binding []byte, err error) {
exportedValue, h, err := ExportAttestationValue(st, label, certificateRequestContext)
if err != nil {
return nil, nil, nil, err
}
pub, err := PublicKeyBytes(leaf)
if err != nil {
return nil, nil, nil, err
}
aikPubHash = AIKPublicKeyHash(h, pub)
binding = BindingValue(h, pub, exportedValue)
return exportedValue, aikPubHash, binding, nil
}
func authenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) {
switch cipherSuite {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256:
return crypto.SHA256, nil
case tls.TLS_AES_256_GCM_SHA384:
return crypto.SHA384, nil
default:
return 0, fmt.Errorf("attestation: unknown cipher suite: %s (0x%04x)", tls.CipherSuiteName(cipherSuite), cipherSuite)
}
}
+199
View File
@@ -0,0 +1,199 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"testing"
"time"
)
type stubEvidenceVerifier struct {
called bool
err error
}
func (s *stubEvidenceVerifier) VerifyEvidence(evidence []byte) error {
s.called = true
return s.err
}
type stubResultsVerifier struct {
called bool
err error
}
func (s *stubResultsVerifier) VerifyAttestationResults(results []byte) error {
s.called = true
return s.err
}
func makeCert(t *testing.T) (tls.Certificate, *x509.Certificate) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "binding"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
t.Fatal(err)
}
leaf, err := x509.ParseCertificate(der)
if err != nil {
t.Fatal(err)
}
return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}, leaf
}
func tls13Client(t *testing.T, cert tls.Certificate) (*tls.Conn, *tls.Conn) {
t.Helper()
srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
a, b := net.Pipe()
srv := tls.Server(a, srvConf)
cli := tls.Client(b, cliConf)
errCh := make(chan error, 2)
go func() { errCh <- srv.Handshake() }()
go func() { errCh <- cli.Handshake() }()
for i := 0; i < 2; i++ {
if err := <-errCh; err != nil {
t.Fatalf("handshake: %v", err)
}
}
return srv, cli
}
func TestComputeBindingDeterministic(t *testing.T) {
cert, leaf := makeCert(t)
srv, cli := tls13Client(t, cert)
defer srv.Close()
defer cli.Close()
st := cli.ConnectionState()
ctx := []byte{1, 2, 3, 4}
ev1, aik1, b1, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
ev2, aik2, b2, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
if len(ev1) != ExportedAttestationValueLen {
t.Fatalf("unexpected exported len: %d", len(ev1))
}
if !bytes.Equal(ev1, ev2) || !bytes.Equal(aik1, aik2) || !bytes.Equal(b1, b2) {
t.Fatalf("expected deterministic outputs for same conn+context")
}
if bytes.Equal(aik1, b1) {
t.Fatalf("unexpected aik == binding")
}
}
func TestPayloadRoundTrip(t *testing.T) {
payload := Payload{
Version: 1,
MediaType: "application/eat+cwt",
Evidence: []byte("evidence"),
Binder: AttestationBinder{
ExporterLabel: ExporterLabelAttestation,
AIKPubHash: []byte("aik"),
Binding: []byte("binding"),
},
}
raw, err := MarshalPayload(payload)
if err != nil {
t.Fatal(err)
}
parsed, err := ParsePayload(raw)
if err != nil {
t.Fatal(err)
}
if parsed.Version != 1 || parsed.MediaType != "application/eat+cwt" || string(parsed.Evidence) != "evidence" {
t.Fatalf("unexpected parsed payload: %#v", parsed)
}
}
func TestVerifyPayloadSuccess(t *testing.T) {
cert, leaf := makeCert(t)
srv, cli := tls13Client(t, cert)
defer srv.Close()
defer cli.Close()
st := cli.ConnectionState()
ctx := []byte{1, 2, 3, 4}
_, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
ev := &stubEvidenceVerifier{}
rv := &stubResultsVerifier{}
payload := &Payload{
Version: 1,
Evidence: []byte("evidence"),
AttestationResults: []byte("results"),
Binder: AttestationBinder{
ExporterLabel: ExporterLabelAttestation,
AIKPubHash: aik,
Binding: binding,
},
}
verified, err := VerifyPayload(&st, ExporterLabelAttestation, ctx, leaf, payload, VerificationPolicy{
EvidenceVerifier: ev,
ResultsVerifier: rv,
})
if err != nil {
t.Fatal(err)
}
if !verified.BindingVerified || !verified.EvidenceVerified || !verified.ResultsVerified {
t.Fatalf("unexpected verification result: %#v", verified)
}
if !ev.called || !rv.called {
t.Fatalf("expected both verifiers to be called")
}
}
func TestVerifyBinderRejectsMismatch(t *testing.T) {
cert, leaf := makeCert(t)
srv, cli := tls13Client(t, cert)
defer srv.Close()
defer cli.Close()
st := cli.ConnectionState()
ctx := []byte{1, 2, 3, 4}
_, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
if err != nil {
t.Fatal(err)
}
binding[0] ^= 0xff
err = VerifyBinder(&st, ExporterLabelAttestation, ctx, leaf, AttestationBinder{
AIKPubHash: aik,
Binding: binding,
})
if err != ErrBindingMismatch {
t.Fatalf("got %v, want %v", err, ErrBindingMismatch)
}
}
+87
View File
@@ -0,0 +1,87 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"encoding/json"
"errors"
)
var (
ErrMalformedPayload = errors.New("attestation: malformed payload")
ErrMissingStatement = errors.New("attestation: missing evidence or attestation results")
ErrMissingBinder = errors.New("attestation: missing attestation binder")
ErrAIKPubHashMismatch = errors.New("attestation: AIK public key hash mismatch")
ErrBindingMismatch = errors.New("attestation: attestation binding mismatch")
ErrEvidenceVerificationMissing = errors.New("attestation: evidence verifier not configured")
ErrResultsVerificationMissing = errors.New("attestation: attestation results verifier not configured")
)
// Payload models the attestation document carried inside the EA certificate-entry extension.
// It intentionally separates the attestation statement from the TLS binding material.
type Payload struct {
Version int `json:"version"`
MediaType string `json:"media_type,omitempty"`
Evidence []byte `json:"evidence,omitempty"`
AttestationResults []byte `json:"attestation_results,omitempty"`
Binder AttestationBinder `json:"binder"`
}
type AttestationBinder struct {
ExporterLabel string `json:"exporter_label,omitempty"`
AIKPubHash []byte `json:"aik_pub_hash,omitempty"`
Binding []byte `json:"binding,omitempty"`
}
type VerifiedPayload struct {
Payload *Payload
EvidenceVerified bool
ResultsVerified bool
BindingVerified bool
UsedExporterLabel string
}
func (p *Payload) Validate() error {
if p == nil {
return ErrMalformedPayload
}
if len(p.Evidence) == 0 && len(p.AttestationResults) == 0 {
return ErrMissingStatement
}
if len(p.Binder.AIKPubHash) == 0 || len(p.Binder.Binding) == 0 {
return ErrMissingBinder
}
return nil
}
func (p *Payload) NormalizedExporterLabel(defaultLabel string) string {
if p == nil || p.Binder.ExporterLabel == "" {
return defaultLabel
}
return p.Binder.ExporterLabel
}
func MarshalPayload(p Payload) ([]byte, error) {
if err := p.Validate(); err != nil {
return nil, err
}
if p.Version == 0 {
p.Version = 1
}
return json.Marshal(p)
}
func ParsePayload(raw []byte) (*Payload, error) {
if len(raw) == 0 {
return nil, ErrMalformedPayload
}
var p Payload
if err := json.Unmarshal(raw, &p); err != nil {
return nil, ErrMalformedPayload
}
if err := p.Validate(); err != nil {
return nil, err
}
return &p, nil
}
+75
View File
@@ -0,0 +1,75 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"crypto/subtle"
"crypto/tls"
"crypto/x509"
)
type EvidenceVerifier interface {
VerifyEvidence(evidence []byte) error
}
type ResultsVerifier interface {
VerifyAttestationResults(results []byte) error
}
type VerificationPolicy struct {
EvidenceVerifier EvidenceVerifier
ResultsVerifier ResultsVerifier
}
func VerifyPayload(st *tls.ConnectionState, defaultLabel string, certificateRequestContext []byte, leaf *x509.Certificate, payload *Payload, policy VerificationPolicy) (*VerifiedPayload, error) {
if err := payload.Validate(); err != nil {
return nil, err
}
verified := &VerifiedPayload{
Payload: payload,
UsedExporterLabel: payload.NormalizedExporterLabel(defaultLabel),
}
if len(payload.Evidence) > 0 && policy.EvidenceVerifier != nil {
if err := policy.EvidenceVerifier.VerifyEvidence(payload.Evidence); err != nil {
return nil, err
}
verified.EvidenceVerified = true
} else if len(payload.Evidence) > 0 {
return nil, ErrEvidenceVerificationMissing
}
if len(payload.AttestationResults) > 0 && policy.ResultsVerifier != nil {
if err := policy.ResultsVerifier.VerifyAttestationResults(payload.AttestationResults); err != nil {
return nil, err
}
verified.ResultsVerified = true
} else if len(payload.AttestationResults) > 0 {
return nil, ErrResultsVerificationMissing
}
if err := VerifyBinder(st, verified.UsedExporterLabel, certificateRequestContext, leaf, payload.Binder); err != nil {
return nil, err
}
verified.BindingVerified = true
return verified, nil
}
func VerifyBinder(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate, binder AttestationBinder) error {
exportedValue, aikPubHash, binding, err := ComputeBinding(st, label, certificateRequestContext, leaf)
if err != nil {
return err
}
_ = exportedValue
if !equalBytes(aikPubHash, binder.AIKPubHash) {
return ErrAIKPubHashMismatch
}
if !equalBytes(binding, binder.Binding) {
return ErrBindingMismatch
}
return nil
}
func equalBytes(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}
+92
View File
@@ -0,0 +1,92 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"fmt"
"os"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/veraison/corim/corim"
)
type policyEvidenceVerifier struct {
policyPath string
}
func NewEvidenceVerifier(policyPath string) eaattestation.EvidenceVerifier {
return &policyEvidenceVerifier{policyPath: policyPath}
}
func (v *policyEvidenceVerifier) VerifyEvidence(evidence []byte) error {
if v.policyPath == "" {
return fmt.Errorf("atls: attestation policy path is not set")
}
claims, err := eat.DecodeCBOR(evidence, nil)
if err != nil {
return fmt.Errorf("atls: failed to decode EAT evidence: %w", err)
}
manifest, err := loadCoRIM(v.policyPath)
if err != nil {
return err
}
verifier, err := platformVerifier(platformTypeFromClaims(claims.PlatformType))
if err != nil {
return err
}
return verifier.VerifyWithCoRIM(claims.RawReport, manifest)
}
func loadCoRIM(path string) (*corim.UnsignedCorim, error) {
corimBytes, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("atls: failed to read CoRIM file: %w", err)
}
var sc corim.SignedCorim
if err := sc.FromCOSE(corimBytes); err == nil {
return &sc.UnsignedCorim, nil
}
var uc corim.UnsignedCorim
if err := uc.FromCBOR(corimBytes); err != nil {
return nil, fmt.Errorf("atls: failed to parse CoRIM: %w", err)
}
return &uc, nil
}
func platformTypeFromClaims(name string) cocosattestation.PlatformType {
switch name {
case "SNP":
return cocosattestation.SNP
case "TDX":
return cocosattestation.TDX
case "vTPM":
return cocosattestation.VTPM
case "SNP-vTPM":
return cocosattestation.SNPvTPM
case "Azure":
return cocosattestation.Azure
default:
return cocosattestation.NoCC
}
}
func platformVerifier(platformType cocosattestation.PlatformType) (cocosattestation.Verifier, error) {
switch platformType {
case cocosattestation.SNP, cocosattestation.SNPvTPM, cocosattestation.VTPM:
return vtpm.NewVerifier(nil), nil
case cocosattestation.Azure:
return azure.NewVerifier(nil), nil
case cocosattestation.TDX:
return tdx.NewVerifier(), nil
default:
return nil, fmt.Errorf("atls: unsupported platform type: %d", platformType)
}
}
+280
View File
@@ -0,0 +1,280 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internaltransport
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/ultravioletrs/cocos/pkg/atls/ea"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
)
type Conn struct {
*tls.Conn
Request *ea.AuthenticatorRequest
ValidationResult *ea.ValidationResult
}
type ClientConfig struct {
TLSConfig *tls.Config
Session *ea.Session
VerifyOptions *x509.VerifyOptions
AttestationPolicy eaattestation.VerificationPolicy
Request *ea.AuthenticatorRequest
RequestBuilder func() (*ea.AuthenticatorRequest, error)
}
type ServerConfig struct {
TLSConfig *tls.Config
Session *ea.Session
Identity tls.Certificate
BuildLeafExtensions func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error)
}
func Dial(network, address string, cfg *ClientConfig) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, address, cfg)
}
func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) {
return DialContextWithDialer(ctx, new(net.Dialer), network, address, cfg)
}
func DialWithDialer(d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
if cfg == nil || cfg.TLSConfig == nil {
return nil, fmt.Errorf("atls: missing client TLS config")
}
rawConn, err := d.Dial(network, address)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone())
conn, err := Client(tlsConn, cfg)
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return conn, nil
}
func DialContextWithDialer(ctx context.Context, d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
if cfg == nil || cfg.TLSConfig == nil {
return nil, fmt.Errorf("atls: missing client TLS config")
}
if ctx == nil {
return nil, fmt.Errorf("atls: missing client context")
}
rawConn, err := d.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone())
conn, err := ClientContext(ctx, tlsConn, cfg)
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return conn, nil
}
func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
return ClientContext(context.Background(), tlsConn, cfg)
}
func ClientContext(ctx context.Context, tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
if cfg == nil {
return nil, fmt.Errorf("atls: missing client config")
}
if ctx == nil {
return nil, fmt.Errorf("atls: missing client context")
}
var res *Conn
err := withConnContext(ctx, tlsConn, func() error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
req, err := buildRequest(cfg)
if err != nil {
return err
}
reqBytes, err := req.Marshal()
if err != nil {
return err
}
if err := writeFrame(tlsConn, frameTypeRequest, reqBytes); err != nil {
return err
}
frameType, authBytes, err := readFrame(tlsConn)
if err != nil {
return err
}
if frameType != frameTypeAuthenticator {
return fmt.Errorf("atls: unexpected frame type %d", frameType)
}
st := tlsConn.ConnectionState()
var validation *ea.ValidationResult
if cfg.Session != nil {
validation, err = cfg.Session.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy)
} else {
validation, err = ea.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy)
}
if err != nil {
return err
}
res = &Conn{Conn: tlsConn, Request: req, ValidationResult: validation}
return nil
})
if err != nil {
return nil, err
}
return res, nil
}
func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) {
if cfg == nil {
return nil, fmt.Errorf("atls: missing server config")
}
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
frameType, reqBytes, err := readFrame(tlsConn)
if err != nil {
return nil, err
}
if frameType != frameTypeRequest {
return nil, fmt.Errorf("atls: unexpected frame type %d", frameType)
}
req, rest, err := ea.UnmarshalAuthenticatorRequest(reqBytes)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("atls: trailing request bytes")
}
st := tlsConn.ConnectionState()
identity, err := resolveIdentity(cfg)
if err != nil {
return nil, err
}
exts, err := buildServerExtensions(cfg, &st, &req, identity)
if err != nil {
return nil, err
}
var authBytes []byte
if cfg.Session != nil {
authBytes, err = cfg.Session.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts)
} else {
authBytes, err = ea.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts)
}
if err != nil {
return nil, err
}
if err := writeFrame(tlsConn, frameTypeAuthenticator, authBytes); err != nil {
return nil, err
}
return &Conn{Conn: tlsConn, Request: &req}, nil
}
func buildRequest(cfg *ClientConfig) (*ea.AuthenticatorRequest, error) {
if cfg.RequestBuilder != nil {
return cfg.RequestBuilder()
}
if cfg.Request != nil {
return cfg.Request, nil
}
ctx, err := ea.NewRandomContext(32)
if err != nil {
return nil, err
}
sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)})
if err != nil {
return nil, err
}
return &ea.AuthenticatorRequest{
Type: ea.HandshakeTypeClientCertificateRequest,
Context: ctx,
Extensions: []ea.Extension{
sigExt,
ea.CMWAttestationOfferExtension(),
},
}, nil
}
func resolveIdentity(cfg *ServerConfig) (tls.Certificate, error) {
if len(cfg.Identity.Certificate) > 0 && cfg.Identity.PrivateKey != nil {
return cfg.Identity, nil
}
if cfg.TLSConfig != nil && len(cfg.TLSConfig.Certificates) > 0 {
return cfg.TLSConfig.Certificates[0], nil
}
return tls.Certificate{}, fmt.Errorf("atls: missing server identity")
}
func buildServerExtensions(cfg *ServerConfig, st *tls.ConnectionState, req *ea.AuthenticatorRequest, identity tls.Certificate) ([]ea.Extension, error) {
if cfg.BuildLeafExtensions == nil {
return nil, nil
}
if len(identity.Certificate) == 0 {
return nil, fmt.Errorf("atls: missing server leaf certificate")
}
leaf, err := x509.ParseCertificate(identity.Certificate[0])
if err != nil {
return nil, err
}
return cfg.BuildLeafExtensions(st, req, leaf)
}
func withConnContext(ctx context.Context, conn interface{ SetDeadline(time.Time) error }, fn func() error) error {
if ctx == nil {
return fn()
}
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return err
}
defer func() {
_ = conn.SetDeadline(time.Time{})
}()
} else {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = conn.SetDeadline(time.Now())
case <-done:
}
}()
defer func() {
close(done)
_ = conn.SetDeadline(time.Time{})
}()
}
err := fn()
if ctxErr := ctx.Err(); ctxErr != nil {
return ctxErr
}
return err
}
+100
View File
@@ -0,0 +1,100 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internaltransport
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"testing"
"time"
)
func selfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "internal-transport"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
t.Fatal(err)
}
return tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
}
func TestServerAllowsIdentityWithoutTLSConfig(t *testing.T) {
cert := selfSignedCert(t)
a, b := net.Pipe()
serverTLS := tls.Server(a, &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
})
clientTLS := tls.Client(b, &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
})
type result struct {
conn *Conn
err error
}
serverCh := make(chan result, 1)
clientCh := make(chan result, 1)
go func() {
conn, err := Server(serverTLS, &ServerConfig{
Identity: cert,
})
serverCh <- result{conn: conn, err: err}
}()
go func() {
conn, err := Client(clientTLS, &ClientConfig{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
},
})
clientCh <- result{conn: conn, err: err}
}()
srvRes := <-serverCh
cliRes := <-clientCh
if srvRes.err != nil {
t.Fatalf("server failed: %v", srvRes.err)
}
if cliRes.err != nil {
t.Fatalf("client failed: %v", cliRes.err)
}
defer srvRes.conn.Close()
defer cliRes.conn.Close()
}
+49
View File
@@ -0,0 +1,49 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internaltransport
import (
"crypto/tls"
"fmt"
"net"
)
type Listener struct {
raw net.Listener
cfg *ServerConfig
}
func Listen(network, address string, cfg *ServerConfig) (*Listener, error) {
if cfg == nil || cfg.TLSConfig == nil {
return nil, fmt.Errorf("atls: missing server TLS config")
}
raw, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{raw: raw, cfg: cfg}, nil
}
func (l *Listener) Accept() (net.Conn, error) {
rawConn, err := l.raw.Accept()
if err != nil {
return nil, err
}
tlsConn := tls.Server(rawConn, l.cfg.TLSConfig.Clone())
conn, err := Server(tlsConn, l.cfg)
if err != nil {
_ = tlsConn.Close()
return nil, err
}
return conn, nil
}
func (l *Listener) Close() error {
return l.raw.Close()
}
func (l *Listener) Addr() net.Addr {
return l.raw.Addr()
}
+62
View File
@@ -0,0 +1,62 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internaltransport
import (
"encoding/binary"
"fmt"
"io"
)
const (
frameTypeRequest uint8 = iota + 1
frameTypeAuthenticator
maxFramePayloadLen = 16 << 20 // 16 MiB
)
func writeFrame(w io.Writer, typ uint8, payload []byte) error {
if len(payload) > maxFramePayloadLen {
return fmt.Errorf("atls: frame payload too large: %d", len(payload))
}
header := make([]byte, 5)
header[0] = typ
binary.BigEndian.PutUint32(header[1:5], uint32(len(payload)))
if err := writeAll(w, header); err != nil {
return err
}
if len(payload) == 0 {
return nil
}
return writeAll(w, payload)
}
func writeAll(w io.Writer, b []byte) error {
for len(b) > 0 {
n, err := w.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}
func readFrame(r io.Reader) (uint8, []byte, error) {
header := make([]byte, 5)
if _, err := io.ReadFull(r, header); err != nil {
return 0, nil, err
}
if header[0] == 0 {
return 0, nil, fmt.Errorf("atls: invalid frame type")
}
n := binary.BigEndian.Uint32(header[1:5])
if n > maxFramePayloadLen {
return 0, nil, fmt.Errorf("atls: frame payload too large: %d", n)
}
payload := make([]byte, n)
if _, err := io.ReadFull(r, payload); err != nil {
return 0, nil, err
}
return header[0], payload, nil
}
@@ -0,0 +1,40 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internaltransport
import (
"bytes"
"encoding/binary"
"strings"
"testing"
)
func TestReadFrameRejectsOversizedPayload(t *testing.T) {
var buf bytes.Buffer
header := make([]byte, 5)
header[0] = frameTypeRequest
binary.BigEndian.PutUint32(header[1:5], maxFramePayloadLen+1)
if _, err := buf.Write(header); err != nil {
t.Fatal(err)
}
_, _, err := readFrame(&buf)
if err == nil {
t.Fatal("expected oversized frame error")
}
if !strings.Contains(err.Error(), "frame payload too large") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestWriteFrameRejectsOversizedPayload(t *testing.T) {
payload := make([]byte, maxFramePayloadLen+1)
err := writeFrame(bytes.NewBuffer(nil), frameTypeAuthenticator, payload)
if err == nil {
t.Fatal("expected oversized frame error")
}
if !strings.Contains(err.Error(), "frame payload too large") {
t.Fatalf("unexpected error: %v", err)
}
}
+25 -75
View File
@@ -1,103 +1,53 @@
// Code generated manually for tests.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"crypto/tls"
"crypto/x509"
mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/atls/ea"
)
// NewCertificateProvider creates a new instance of CertificateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewCertificateProvider(t interface {
mock.TestingT
Cleanup(func())
}) *CertificateProvider {
mock := &CertificateProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// CertificateProvider is an autogenerated mock type for the CertificateProvider type
type CertificateProvider struct {
mock.Mock
}
type CertificateProvider_Expecter struct {
mock *mock.Mock
}
func (_m *CertificateProvider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) {
ret := _m.Called(st, req, leaf)
func (_m *CertificateProvider) EXPECT() *CertificateProvider_Expecter {
return &CertificateProvider_Expecter{mock: &_m.Mock}
}
// GetCertificate provides a mock function for the type CertificateProvider
func (_mock *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ret := _mock.Called(clientHello)
if len(ret) == 0 {
panic("no return value specified for GetCertificate")
}
var r0 *tls.Certificate
var r0 []ea.Extension
var r1 error
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok {
return returnFunc(clientHello)
if rf, ok := ret.Get(0).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error)); ok {
return rf(st, req, leaf)
}
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) *tls.Certificate); ok {
r0 = returnFunc(clientHello)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*tls.Certificate)
}
if ret.Get(0) != nil {
r0 = ret.Get(0).([]ea.Extension)
}
if returnFunc, ok := ret.Get(1).(func(*tls.ClientHelloInfo) error); ok {
r1 = returnFunc(clientHello)
if rf, ok := ret.Get(1).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) error); ok {
r1 = rf(st, req, leaf)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CertificateProvider_GetCertificate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCertificate'
type CertificateProvider_GetCertificate_Call struct {
*mock.Call
}
func NewCertificateProvider(t interface {
mock.TestingT
Cleanup(func())
}) *CertificateProvider {
mockProvider := &CertificateProvider{}
mockProvider.Mock.Test(t)
// GetCertificate is a helper method to define mock.On call
// - clientHello *tls.ClientHelloInfo
func (_e *CertificateProvider_Expecter) GetCertificate(clientHello interface{}) *CertificateProvider_GetCertificate_Call {
return &CertificateProvider_GetCertificate_Call{Call: _e.mock.On("GetCertificate", clientHello)}
}
func (_c *CertificateProvider_GetCertificate_Call) Run(run func(clientHello *tls.ClientHelloInfo)) *CertificateProvider_GetCertificate_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 *tls.ClientHelloInfo
if args[0] != nil {
arg0 = args[0].(*tls.ClientHelloInfo)
}
run(
arg0,
)
t.Cleanup(func() {
mockProvider.AssertExpectations(t)
})
return _c
}
func (_c *CertificateProvider_GetCertificate_Call) Return(certificate *tls.Certificate, err error) *CertificateProvider_GetCertificate_Call {
_c.Call.Return(certificate, err)
return _c
}
func (_c *CertificateProvider_GetCertificate_Call) RunAndReturn(run func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)) *CertificateProvider_GetCertificate_Call {
_c.Call.Return(run)
return _c
return mockProvider
}
+84
View File
@@ -0,0 +1,84 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"context"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/absmach/certs/sdk"
"github.com/ultravioletrs/cocos/pkg/atls/ea"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation"
attestationclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
)
// CertificateProvider is kept for compatibility with existing cocos call sites.
// In the EA-based implementation it provides the leaf certificate-entry extensions
// carried in the exported authenticator instead of generating TLS certificates.
type CertificateProvider interface {
BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error)
}
type provider struct {
attClient attestationclient.Client
platformType cocosattestation.PlatformType
}
func NewProvider(attClient attestationclient.Client, platformType cocosattestation.PlatformType, _ string, _ string, _ sdk.SDK) (CertificateProvider, error) {
if attClient == nil {
return nil, fmt.Errorf("atls: missing attestation client")
}
if platformType == cocosattestation.NoCC {
return nil, fmt.Errorf("atls: confidential computing platform not available")
}
return &provider{
attClient: attClient,
platformType: platformType,
}, nil
}
func (p *provider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) {
if st == nil || req == nil || leaf == nil {
return nil, fmt.Errorf("atls: missing state, request, or leaf certificate")
}
exportedValue, aikPubHash, binding, err := eaattestation.ComputeBinding(st, eaattestation.ExporterLabelAttestation, req.Context, leaf)
if err != nil {
return nil, err
}
reportData := sha512.Sum512(binding)
nonceBytes := sha256.Sum256(exportedValue)
var nonce [32]byte
copy(nonce[:], nonceBytes[:])
evidence, err := p.attClient.GetAttestation(context.Background(), reportData, nonce, p.platformType)
if err != nil {
return nil, fmt.Errorf("atls: failed to fetch attestation evidence: %w", err)
}
payloadBytes, err := eaattestation.MarshalPayload(eaattestation.Payload{
Version: 1,
MediaType: "application/eat+cwt",
Evidence: evidence,
Binder: eaattestation.AttestationBinder{
ExporterLabel: eaattestation.ExporterLabelAttestation,
AIKPubHash: aikPubHash,
Binding: binding,
},
})
if err != nil {
return nil, err
}
ext, err := ea.CMWAttestationDataExtension(payloadBytes)
if err != nil {
return nil, err
}
return []ea.Extension{ext}, nil
}
+93
View File
@@ -0,0 +1,93 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"time"
)
// BuildServerTLSConfig prepares the base TLS configuration used by the EA/aTLS
// transport. If no certificate/key pair is configured, it falls back to an
// ephemeral self-signed identity bound by the exported authenticator.
func BuildServerTLSConfig(certFile, keyFile, serverCAFile, clientCAFile string) (*tls.Config, tls.Certificate, bool, error) {
if certFile != "" || keyFile != "" {
tlsSetup, err := setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile)
if err != nil {
return nil, tls.Certificate{}, false, err
}
tlsSetup.config.MinVersion = tls.VersionTLS13
return tlsSetup.config, tlsSetup.config.Certificates[0], tlsSetup.mtls, nil
}
identity, err := generateEphemeralIdentity()
if err != nil {
return nil, tls.Certificate{}, false, fmt.Errorf("failed to generate ephemeral TLS identity: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{identity},
}
mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
if err != nil {
return nil, tls.Certificate{}, false, err
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
return tlsConfig, identity, mtls, nil
}
func generateEphemeralIdentity() (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialLimit)
if err != nil {
return tls.Certificate{}, err
}
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "cocos-atls-ephemeral",
Organization: []string{"Ultraviolet"},
},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
leaf, err := x509.ParseCertificate(der)
if err != nil {
return tls.Certificate{}, err
}
return tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
Leaf: leaf,
}, nil
}
+104
View File
@@ -0,0 +1,104 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strings"
)
type tlsSetupResult struct {
config *tls.Config
mtls bool
}
func readFileOrData(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
return nil, err
}
return []byte(input), nil
}
func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) {
cert, err := readFileOrData(certFile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %w", err)
}
key, err := readFileOrData(keyFile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %w", err)
}
return tls.X509KeyPair(cert, key)
}
func loadCertFile(certFile string) ([]byte, error) {
if certFile == "" {
return []byte{}, nil
}
return readFileOrData(certFile)
}
func configureCertificateAuthorities(tlsConfig *tls.Config, serverCAFile, clientCAFile string) (bool, error) {
rootCA, err := loadCertFile(serverCAFile)
if err != nil {
return false, fmt.Errorf("failed to load server ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return false, fmt.Errorf("failed to append server ca to tls.Config")
}
}
clientCA, err := loadCertFile(clientCAFile)
if err != nil {
return false, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) == 0 {
return false, nil
}
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return false, fmt.Errorf("failed to append client ca to tls.Config")
}
return true, nil
}
func setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*tlsSetupResult, error) {
certificate, err := loadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}
mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
if err != nil {
return nil, err
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
return &tlsSetupResult{config: tlsConfig, mtls: mtls}, nil
}
+83
View File
@@ -0,0 +1,83 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"github.com/ultravioletrs/cocos/pkg/atls/ea"
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
internaltransport "github.com/ultravioletrs/cocos/pkg/atls/internal_transport"
)
type Conn = internaltransport.Conn
type Listener = internaltransport.Listener
type ClientConfig = internaltransport.ClientConfig
type ServerConfig = internaltransport.ServerConfig
type AuthenticatorRequest = ea.AuthenticatorRequest
func Dial(network, address string, cfg *ClientConfig) (*Conn, error) {
return internaltransport.Dial(network, address, cfg)
}
func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) {
return internaltransport.DialContext(ctx, network, address, cfg)
}
func DialWithDialer(dialer *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
return internaltransport.DialWithDialer(dialer, network, address, cfg)
}
func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
return internaltransport.Client(tlsConn, cfg)
}
func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) {
return internaltransport.Server(tlsConn, cfg)
}
func Listen(network, address string, cfg *ServerConfig) (*Listener, error) {
return internaltransport.Listen(network, address, cfg)
}
func NewRequest(context []byte) (*ea.AuthenticatorRequest, error) {
sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)})
if err != nil {
return nil, err
}
return &ea.AuthenticatorRequest{
Type: ea.HandshakeTypeClientCertificateRequest,
Context: context,
Extensions: []ea.Extension{
sigExt,
ea.CMWAttestationOfferExtension(),
},
}, nil
}
func NewRandomRequest(contextLen int) (*ea.AuthenticatorRequest, error) {
context, err := ea.NewRandomContext(contextLen)
if err != nil {
return nil, err
}
return NewRequest(context)
}
func VerifyOptionsFromTLSConfig(cfg *tls.Config) *x509.VerifyOptions {
if cfg == nil || cfg.InsecureSkipVerify || cfg.RootCAs == nil {
return nil
}
return &x509.VerifyOptions{Roots: cfg.RootCAs}
}
func VerificationPolicyFromEvidenceVerifier(v eaattestation.EvidenceVerifier) eaattestation.VerificationPolicy {
return eaattestation.VerificationPolicy{EvidenceVerifier: v}
}
+78
View File
@@ -0,0 +1,78 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestVerifyOptionsFromTLSConfig(t *testing.T) {
t.Run("nil config", func(t *testing.T) {
if got := VerifyOptionsFromTLSConfig(nil); got != nil {
t.Fatalf("expected nil verify options, got %#v", got)
}
})
t.Run("skip verify disables ea chain validation", func(t *testing.T) {
got := VerifyOptionsFromTLSConfig(&tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
})
if got != nil {
t.Fatalf("expected nil verify options for insecure skip verify, got %#v", got)
}
})
t.Run("missing roots disables ea chain validation", func(t *testing.T) {
got := VerifyOptionsFromTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS13,
})
if got != nil {
t.Fatalf("expected nil verify options when roots are not configured, got %#v", got)
}
})
t.Run("configured roots are propagated", func(t *testing.T) {
roots := x509.NewCertPool()
got := VerifyOptionsFromTLSConfig(&tls.Config{
RootCAs: roots,
MinVersion: tls.VersionTLS13,
})
if got == nil {
t.Fatal("expected verify options, got nil")
}
if got.Roots != roots {
t.Fatal("expected verify options to reuse configured root CAs")
}
})
}
func TestNewRandomRequest(t *testing.T) {
req1, err := NewRandomRequest(32)
if err != nil {
t.Fatalf("first request failed: %v", err)
}
req2, err := NewRandomRequest(32)
if err != nil {
t.Fatalf("second request failed: %v", err)
}
if len(req1.Context) != 32 {
t.Fatalf("expected first request context length 32, got %d", len(req1.Context))
}
if len(req2.Context) != 32 {
t.Fatalf("expected second request context length 32, got %d", len(req2.Context))
}
if len(req1.Extensions) == 0 {
t.Fatal("expected first request to carry extensions")
}
if len(req2.Extensions) == 0 {
t.Fatal("expected second request to carry extensions")
}
if string(req1.Context) == string(req2.Context) {
t.Fatal("expected random request contexts to differ")
}
}
+32 -1
View File
@@ -3,11 +3,18 @@
package clients
import "time"
import (
"encoding/hex"
"errors"
"fmt"
"time"
)
var (
_ ClientConfiguration = (*AttestedClientConfig)(nil)
_ ClientConfiguration = (*StandardClientConfig)(nil)
ErrInvalidAttestationRequestContext = errors.New("invalid attestation request context")
)
type ClientConfiguration interface {
@@ -29,6 +36,13 @@ type AttestedClientConfig struct {
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
// AttestationRequestContextHex, when set, is decoded from hex and used as
// the exported authenticator certificate_request_context. This lets the
// caller provide the background-check freshness value directly.
AttestationRequestContextHex string `env:"ATTESTATION_REQUEST_CONTEXT" envDefault:""`
// AttestationRequestContext allows callers inside the same process to pass
// raw request-context bytes directly instead of using the hex string form.
AttestationRequestContext []byte `env:"-"`
}
func (c AttestedClientConfig) Config() StandardClientConfig {
@@ -38,3 +52,20 @@ func (c AttestedClientConfig) Config() StandardClientConfig {
func (c StandardClientConfig) Config() StandardClientConfig {
return c
}
func (c AttestedClientConfig) RequestContext() ([]byte, error) {
if len(c.AttestationRequestContext) > 0 {
return append([]byte(nil), c.AttestationRequestContext...), nil
}
if c.AttestationRequestContextHex == "" {
return nil, nil
}
requestContext, err := hex.DecodeString(c.AttestationRequestContextHex)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrInvalidAttestationRequestContext, err)
}
if len(requestContext) == 0 {
return nil, fmt.Errorf("%w: decoded value is empty", ErrInvalidAttestationRequestContext)
}
return requestContext, nil
}
+32
View File
@@ -103,6 +103,22 @@ func TestNewClient(t *testing.T) {
wantErr: false,
err: nil,
},
{
name: "Success agent client with aTLS and custom request context",
agentCfg: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
ClientKey: clientKeyFile,
},
AttestedTLS: true,
AttestationPolicy: policyFile.Name(),
AttestationRequestContextHex: "01020304",
},
wantErr: false,
err: nil,
},
{
name: "Failed agent client with aTLS",
agentCfg: clients.AttestedClientConfig{
@@ -118,6 +134,22 @@ func TestNewClient(t *testing.T) {
wantErr: true,
err: fmt.Errorf("failed to stat attestation policy file"),
},
{
name: "Failed agent client with invalid attestation request context",
agentCfg: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
ClientKey: clientKeyFile,
},
AttestedTLS: true,
AttestationPolicy: policyFile.Name(),
AttestationRequestContextHex: "xyz",
},
wantErr: true,
err: clients.ErrInvalidAttestationRequestContext,
},
{
name: "Fail with invalid ServerCAFile",
cfg: clients.StandardClientConfig{
+53 -1
View File
@@ -4,7 +4,13 @@
package grpc
import (
"context"
stdtls "crypto/tls"
"net"
"strings"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -77,7 +83,43 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e
return nil, security, err
}
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(result.Config)))
tlsConfig := result.Config.Clone()
tlsConfig.MinVersion = stdtls.VersionTLS13
tlsConfig.NextProtos = []string{"h2"}
atlsConfig := &atls.ClientConfig{
TLSConfig: tlsConfig,
VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig),
AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)),
}
requestContext, err := agcfg.RequestContext()
if err != nil {
return nil, security, err
}
if len(requestContext) > 0 {
req, err := atls.NewRequest(requestContext)
if err != nil {
return nil, security, err
}
atlsConfig.Request = req
} else {
atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) {
return atls.NewRandomRequest(32)
}
}
opts = append(opts,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
network, target := dialTarget(addr)
conn, err := atls.DialContext(ctx, network, target, atlsConfig)
if err != nil {
return nil, err
}
return conn, nil
}),
)
security = result.Security
} else {
conf := cfg.Config()
@@ -96,6 +138,16 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e
return conn, security, nil
}
func dialTarget(addr string) (string, string) {
if strings.HasPrefix(addr, "unix://") {
return "unix", strings.TrimPrefix(addr, "unix://")
}
if strings.HasPrefix(addr, "/") {
return "unix", addr
}
return "tcp", addr
}
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
if err != nil {
+42 -1
View File
@@ -4,9 +4,14 @@
package http
import (
"context"
stdtls "crypto/tls"
"net"
"net/http"
"strings"
"time"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
@@ -70,7 +75,33 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu
return nil, security, err
}
transport.TLSClientConfig = result.Config
tlsConfig := result.Config.Clone()
tlsConfig.MinVersion = stdtls.VersionTLS13
atlsConfig := &atls.ClientConfig{
TLSConfig: tlsConfig,
VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig),
AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)),
}
requestContext, err := agcfg.RequestContext()
if err != nil {
return nil, security, err
}
if len(requestContext) > 0 {
req, err := atls.NewRequest(requestContext)
if err != nil {
return nil, security, err
}
atlsConfig.Request = req
} else {
atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) {
return atls.NewRandomRequest(32)
}
}
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialNetwork, target := httpDialTarget(network, addr)
return atls.DialContext(ctx, dialNetwork, target, atlsConfig)
}
security = result.Security
} else {
conf := cfg.Config()
@@ -89,3 +120,13 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu
return transport, security, nil
}
func httpDialTarget(network, addr string) (string, string) {
if strings.HasPrefix(addr, "unix://") {
return "unix", strings.TrimPrefix(addr, "unix://")
}
if strings.HasPrefix(addr, "/") {
return "unix", addr
}
return network, addr
}
+57
View File
@@ -5,6 +5,7 @@ package http
import (
"net/http"
"os"
"testing"
"time"
@@ -210,6 +211,62 @@ func TestCreateTransport_ATLSError(t *testing.T) {
assert.Contains(t, err.Error(), "failed to stat attestation policy")
}
func TestCreateTransport_ATLSCustomRequestContext(t *testing.T) {
policyFile, err := os.CreateTemp("", "attestation_policy.json")
assert.NoError(t, err)
_, err = policyFile.WriteString("{}")
assert.NoError(t, err)
assert.NoError(t, policyFile.Close())
t.Cleanup(func() {
_ = os.Remove(policyFile.Name())
})
config := &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
},
AttestationPolicy: policyFile.Name(),
AttestedTLS: true,
AttestationRequestContextHex: "01020304",
}
transport, security, err := createTransport(config)
assert.NoError(t, err)
assert.NotNil(t, transport)
assert.Equal(t, tls.WithATLS, security)
assert.NotNil(t, transport.DialTLSContext)
}
func TestCreateTransport_ATLSInvalidRequestContext(t *testing.T) {
policyFile, err := os.CreateTemp("", "attestation_policy.json")
assert.NoError(t, err)
_, err = policyFile.WriteString("{}")
assert.NoError(t, err)
assert.NoError(t, policyFile.Close())
t.Cleanup(func() {
_ = os.Remove(policyFile.Name())
})
config := &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
},
AttestationPolicy: policyFile.Name(),
AttestedTLS: true,
AttestationRequestContextHex: "xyz",
}
transport, security, err := createTransport(config)
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, tls.WithoutTLS, security)
assert.Contains(t, err.Error(), "invalid attestation request context")
}
func TestCreateTransport_BasicTLSError(t *testing.T) {
config := clients.StandardClientConfig{
URL: "https://example.com",
+58 -28
View File
@@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"os"
"sync"
"github.com/ultravioletrs/cocos/pkg/atls"
@@ -18,6 +19,8 @@ import (
"golang.org/x/net/http2/h2c"
)
const unix = "unix"
// ProxyConfig contains configuration for starting a proxy instance.
type ProxyConfig struct {
Port string
@@ -81,12 +84,12 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
var rp *httputil.ReverseProxy
// Check if backend is Unix socket or TCP
if p.backendURL.Scheme == "unix" {
if p.backendURL.Scheme == unix {
// For Unix socket backend, we need to manually configure the reverse proxy
// because NewSingleHostReverseProxy doesn't support unix:// scheme
targetURL := &url.URL{
Scheme: "http",
Host: "unix",
Host: unix,
}
rp = httputil.NewSingleHostReverseProxy(targetURL)
@@ -96,7 +99,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
originalDirector(req)
// Set the URL to point to the backend service
req.URL.Scheme = "http"
req.URL.Host = "unix"
req.URL.Host = unix
}
// Configure Transport for Unix socket with HTTP/2
@@ -105,7 +108,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
var d net.Dialer
// Use Unix socket path from URL
return d.DialContext(ctx, "unix", p.backendURL.Path)
return d.DialContext(ctx, unix, p.backendURL.Path)
},
}
} else {
@@ -130,6 +133,8 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
// Configure TLS
var tlsConfig *tls.Config
var listener net.Listener
var err error
contextDesc := fmt.Sprintf("context %s", ctx.ID)
if ctx.Name != "" {
contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID)
@@ -139,23 +144,31 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
if p.certProvider == nil {
return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running")
}
tlsConfig = &tls.Config{
GetCertificate: p.certProvider.GetCertificate,
ClientAuth: tls.NoClientCert,
NextProtos: []string{"h2", "http/1.1"},
tlsConfig, identity, mtls, err := atls.BuildServerTLSConfig(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup attested TLS: %w", err)
}
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
listener, err = p.attestedListener(addr, tlsConfig, identity)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
p.started = true
go func() {
serveErr := p.httpServer.Serve(listener)
if serveErr != nil && serveErr != http.ErrServerClosed {
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr))
}
}()
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc))
} else {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc))
}
return nil
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
// Regular TLS
tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
@@ -174,31 +187,48 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc))
}
if tlsConfig != nil {
tcpListener, listenErr := net.Listen("tcp", addr)
if listenErr != nil {
return fmt.Errorf("failed to listen: %w", listenErr)
}
listener = tls.NewListener(tcpListener, tlsConfig)
} else {
listener, err = net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
}
p.started = true
// Start server in goroutine
go func() {
var err error
if tlsConfig != nil {
ln, listenErr := net.Listen("tcp", addr)
if listenErr != nil {
p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr))
return
}
tlsLn := tls.NewListener(ln, tlsConfig)
err = p.httpServer.Serve(tlsLn)
} else {
err = p.httpServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err))
serveErr := p.httpServer.Serve(listener)
if serveErr != nil && serveErr != http.ErrServerClosed {
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr))
}
}()
return nil
}
func (p *proxyServer) attestedListener(addr string, tlsConfig *tls.Config, identity tls.Certificate) (net.Listener, error) {
network := "tcp"
address := addr
if len(addr) > 0 && addr[0] == '/' {
network = unix
address = addr
_ = os.Remove(address)
}
return atls.Listen(network, address, &atls.ServerConfig{
TLSConfig: tlsConfig,
Identity: identity,
BuildLeafExtensions: p.certProvider.BuildLeafExtensions,
})
}
// Stop stops the proxy server.
func (p *proxyServer) Stop() error {
p.mu.Lock()
+32 -1
View File
@@ -18,6 +18,7 @@ import (
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
@@ -113,6 +114,9 @@ func TestProxyStartWithoutPort(t *testing.T) {
ctx := ProxyContext{ID: "test-2"}
err := ps.Start(cfg, ctx)
if err != nil && strings.Contains(err.Error(), "address already in use") {
t.Skip("default ingress port 7002 is already in use")
}
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
@@ -141,6 +145,33 @@ func TestProxyStartAlreadyStarted(t *testing.T) {
assert.Equal(t, "proxy server already started", err.Error())
}
func TestProxyStartReturnsListenerError(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil).(*proxyServer)
occupied, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer occupied.Close()
port := occupied.Addr().(*net.TCPAddr).Port
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "test-bind-failure"}
err = ps.Start(cfg, ctx)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to listen")
assert.False(t, ps.started)
retry, retryErr := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, retryErr)
retryPort := retry.Addr().(*net.TCPAddr).Port
retry.Close()
err = ps.Start(ProxyConfig{Port: fmt.Sprintf("%d", retryPort)}, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
}
// TestProxyStartAfterStopped tests error when starting after stop.
func TestProxyStartAfterStopped(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
@@ -473,5 +504,5 @@ func TestProxyAttestedTLSInvalidCA(t *testing.T) {
err := ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to configure certificate authorities")
assert.Contains(t, err.Error(), "failed to setup attested TLS")
}
+5 -17
View File
@@ -4,15 +4,12 @@
package tls
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"os"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation"
)
@@ -125,21 +122,12 @@ func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey strin
security = WithMATLS
}
nonce := make([]byte, 64)
if _, err := rand.Read(nonce); err != nil {
return nil, errors.Wrap(errors.New("failed to generate nonce"), err)
}
encoded := hex.EncodeToString(nonce)
sni := encoded + ".nonce"
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
RootCAs: rootCAs,
ServerName: sni,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return atls.NewCertificateVerifier(rootCAs).VerifyPeerCertificate(rawCerts, verifiedChains, nonce)
},
MinVersion: tls.VersionTLS13,
RootCAs: rootCAs,
}
if rootCAs == nil {
tlsConfig.InsecureSkipVerify = true
}
if clientCert != "" || clientKey != "" {
+6 -4
View File
@@ -6,6 +6,7 @@ package tls
import (
"crypto/rand"
"crypto/rsa"
stdtls "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
@@ -296,10 +297,11 @@ func TestLoadATLSConfig(t *testing.T) {
assert.NotNil(t, result.Config)
// Verify TLS config properties
assert.True(t, result.Config.InsecureSkipVerify)
assert.NotNil(t, result.Config.VerifyPeerCertificate)
assert.NotEmpty(t, result.Config.ServerName)
assert.Contains(t, result.Config.ServerName, ".nonce")
assert.Equal(t, tt.serverCAFile == "", result.Config.InsecureSkipVerify)
assert.Equal(t, uint16(stdtls.VersionTLS13), result.Config.MinVersion)
if tt.serverCAFile != "" {
assert.NotNil(t, result.Config.RootCAs)
}
})
}
}