mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Post-handshake aTLS (#582)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* initial post-handshake aTLS implementation * add header * rebased * remove grpc.go and http.go * fix authenticator issues * add freshness nonce --------- Co-authored-by: ultraviolet <cocosai@worker-52.local.pragmatic-it.com> Co-authored-by: ultraviolet <cocosai@k8s-master.local.pragmatic-it.com>
This commit is contained in:
committed by
GitHub
parent
42b05524c8
commit
80bf813c48
@@ -1,69 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultNotAfterYears = 1
|
||||
nonceLength = 64
|
||||
nonceSuffix = ".nonce"
|
||||
)
|
||||
|
||||
// Platform-specific OIDs for certificate extensions.
|
||||
var (
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
)
|
||||
|
||||
// CertificateSubject contains certificate subject information.
|
||||
type CertificateSubject struct {
|
||||
Organization string
|
||||
CommonName string
|
||||
Country string
|
||||
Province string
|
||||
Locality string
|
||||
StreetAddress string
|
||||
PostalCode string
|
||||
}
|
||||
|
||||
// DefaultCertificateSubject returns the default certificate subject for Ultraviolet.
|
||||
func DefaultCertificateSubject() CertificateSubject {
|
||||
return CertificateSubject{
|
||||
Organization: "Ultraviolet",
|
||||
CommonName: "Ultraviolet",
|
||||
Country: "Serbia",
|
||||
Province: "",
|
||||
Locality: "Belgrade",
|
||||
StreetAddress: "Bulevar Arsenija Carnojevica 103",
|
||||
PostalCode: "11000",
|
||||
}
|
||||
}
|
||||
|
||||
func extractNonceFromSNI(serverName string) ([]byte, error) {
|
||||
if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) {
|
||||
return nil, fmt.Errorf("invalid server name: %s", serverName)
|
||||
}
|
||||
|
||||
nonceStr := serverName[:len(serverName)-len(nonceSuffix)]
|
||||
nonce, err := hex.DecodeString(nonceStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode nonce: %w", err)
|
||||
}
|
||||
|
||||
if len(nonce) != nonceLength {
|
||||
return nil, fmt.Errorf("invalid nonce length: expected %d bytes, got %d bytes", nonceLength, len(nonce))
|
||||
}
|
||||
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
func hasNonceSuffix(serverName string) bool {
|
||||
return len(serverName) >= len(nonceSuffix) &&
|
||||
serverName[len(serverName)-len(nonceSuffix):] == nonceSuffix
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,82 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
// AttestationProvider defines the interface for platform attestation operations.
|
||||
type AttestationProvider interface {
|
||||
Attest(pubKey []byte, nonce []byte) ([]byte, error)
|
||||
OID() asn1.ObjectIdentifier
|
||||
PlatformType() attestation.PlatformType
|
||||
}
|
||||
|
||||
// PlatformAttestationProvider handles platform attestation operations.
|
||||
type platformAttestationProvider struct {
|
||||
attClient attestation_client.Client
|
||||
oid asn1.ObjectIdentifier
|
||||
platformType attestation.PlatformType
|
||||
}
|
||||
|
||||
// NewAttestationProvider creates a new attestation provider for the given platform type.
|
||||
func NewAttestationProvider(attClient attestation_client.Client, platformType attestation.PlatformType) (AttestationProvider, error) {
|
||||
oid, err := OID(platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OID: %w", err)
|
||||
}
|
||||
|
||||
return &platformAttestationProvider{
|
||||
attClient: attClient,
|
||||
oid: oid,
|
||||
platformType: platformType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
var reportData [64]byte
|
||||
copy(reportData[:], hashNonce[:])
|
||||
|
||||
var nonceArray [32]byte
|
||||
copy(nonceArray[:], hashNonce[:32])
|
||||
|
||||
// Get signed EAT token from attestation service
|
||||
// The attestation service maintains a persistent signing key and returns a pre-signed token
|
||||
eatToken, err := p.attClient.GetAttestation(context.Background(), reportData, nonceArray, p.platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attestation from service: %w", err)
|
||||
}
|
||||
|
||||
return eatToken, nil
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier {
|
||||
return p.oid
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) PlatformType() attestation.PlatformType {
|
||||
return p.platformType
|
||||
}
|
||||
|
||||
func OID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return SNPvTPMOID, nil
|
||||
case attestation.Azure:
|
||||
return AzureOID, nil
|
||||
case attestation.TDX:
|
||||
return TDXOID, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/certs"
|
||||
sdk "github.com/absmach/certs/sdk"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
)
|
||||
|
||||
// CertificateProvider defines the interface for providing TLS certificates.
|
||||
type CertificateProvider interface {
|
||||
GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
// AttestedCertificateProvider provides attested TLS certificates.
|
||||
type attestedCertificateProvider struct {
|
||||
attestationProvider AttestationProvider
|
||||
certsSDK sdk.SDK
|
||||
agentToken string
|
||||
subject CertificateSubject
|
||||
useCA bool
|
||||
cvmID string
|
||||
ttl time.Duration
|
||||
notAfterYears int
|
||||
}
|
||||
|
||||
// NewAttestedProvider creates a new attested certificate provider for self-signed certificates.
|
||||
func NewAttestedProvider(
|
||||
attestationProvider AttestationProvider,
|
||||
subject CertificateSubject,
|
||||
) CertificateProvider {
|
||||
return &attestedCertificateProvider{
|
||||
attestationProvider: attestationProvider,
|
||||
subject: subject,
|
||||
useCA: false,
|
||||
notAfterYears: defaultNotAfterYears,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAttestedCAProvider creates a new attested certificate provider for CA-signed certificates.
|
||||
func NewAttestedCAProvider(
|
||||
attestationProvider AttestationProvider,
|
||||
subject CertificateSubject,
|
||||
certsSDK sdk.SDK, cvmID, agentToken string,
|
||||
) CertificateProvider {
|
||||
return &attestedCertificateProvider{
|
||||
attestationProvider: attestationProvider,
|
||||
subject: subject,
|
||||
certsSDK: certsSDK,
|
||||
agentToken: agentToken,
|
||||
useCA: true,
|
||||
cvmID: cvmID,
|
||||
ttl: time.Hour * 24 * 365, // Default 1 year
|
||||
}
|
||||
}
|
||||
|
||||
// SetTTL sets the certificate TTL for CA-signed certificates.
|
||||
func (p *attestedCertificateProvider) SetTTL(ttl time.Duration) {
|
||||
p.ttl = ttl
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := extractNonceFromSNI(clientHello.ServerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract nonce: %w", err)
|
||||
}
|
||||
|
||||
attestationData, err := p.attestationProvider.Attest(pubKeyDER, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attestation: %w", err)
|
||||
}
|
||||
|
||||
extension := pkix.Extension{
|
||||
Id: p.attestationProvider.OID(),
|
||||
Value: attestationData,
|
||||
}
|
||||
|
||||
var certDERBytes []byte
|
||||
if p.useCA {
|
||||
certDERBytes, err = p.generateCASignedCertificate(clientHello.Context(), privateKey, extension)
|
||||
} else {
|
||||
certDERBytes, err = p.generateSelfSignedCertificate(privateKey, extension)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDERBytes},
|
||||
PrivateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) generateSelfSignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{p.subject.Organization},
|
||||
Country: []string{p.subject.Country},
|
||||
Province: []string{p.subject.Province},
|
||||
Locality: []string{p.subject.Locality},
|
||||
StreetAddress: []string{p.subject.StreetAddress},
|
||||
PostalCode: []string{p.subject.PostalCode},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(p.notAfterYears, 0, 0),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
ExtraExtensions: []pkix.Extension{extension},
|
||||
}
|
||||
|
||||
return x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) generateCASignedCertificate(ctx context.Context, privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
|
||||
csrMetadata := certs.CSRMetadata{
|
||||
Organization: []string{p.subject.Organization},
|
||||
Country: []string{p.subject.Country},
|
||||
CommonName: p.subject.CommonName,
|
||||
Province: []string{p.subject.Province},
|
||||
Locality: []string{p.subject.Locality},
|
||||
StreetAddress: []string{p.subject.StreetAddress},
|
||||
PostalCode: []string{p.subject.PostalCode},
|
||||
ExtraExtensions: []pkix.Extension{extension},
|
||||
}
|
||||
|
||||
csr, sdkerr := p.certsSDK.CreateCSR(ctx, csrMetadata, privateKey)
|
||||
if sdkerr != nil {
|
||||
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
|
||||
}
|
||||
|
||||
cert, err := p.certsSDK.IssueFromCSRInternal(ctx, p.cvmID, p.ttl.String(), string(csr.CSR), p.agentToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
|
||||
block, rest := pem.Decode([]byte(cleanCertificateString))
|
||||
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
|
||||
}
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
|
||||
}
|
||||
|
||||
return block.Bytes, nil
|
||||
}
|
||||
|
||||
func NewProvider(attClient attestation_client.Client, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) {
|
||||
attestationProvider, err := NewAttestationProvider(attClient, platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
|
||||
}
|
||||
|
||||
subject := DefaultCertificateSubject()
|
||||
|
||||
if certsSDK != nil {
|
||||
return NewAttestedCAProvider(attestationProvider, subject, certsSDK, cvmID, agentToken), nil
|
||||
}
|
||||
|
||||
return NewAttestedProvider(attestationProvider, subject), nil
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/veraison/corim/corim"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
type CertificateVerifier interface {
|
||||
VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte) error
|
||||
}
|
||||
|
||||
// CertificateVerifier handles certificate verification operations.
|
||||
type certificateVerifier struct {
|
||||
rootCAs *x509.CertPool
|
||||
verifierProvider func(attestation.PlatformType) (attestation.Verifier, error)
|
||||
}
|
||||
|
||||
func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
|
||||
return &certificateVerifier{
|
||||
rootCAs: rootCAs,
|
||||
verifierProvider: platformVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
|
||||
slog.Debug("Starting peer certificate verification for aTLS")
|
||||
if len(rawCerts) == 0 {
|
||||
err := fmt.Errorf("no certificates provided")
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse x509 certificate: %w", err)
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String())
|
||||
|
||||
if err := v.verifyCertificateSignature(cert); err != nil {
|
||||
err = fmt.Errorf("certificate signature verification failed: %w", err)
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully verified peer certificate signature")
|
||||
|
||||
err = v.verifyAttestationExtension(cert, nonce)
|
||||
if err != nil {
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully verified aTLS attestation extension")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
|
||||
rootCAs := v.rootCAs
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
_, err := cert.Verify(opts)
|
||||
return err
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
|
||||
for _, ext := range cert.Extensions {
|
||||
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
|
||||
slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType)
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
return v.verifyCertificateExtension(ext.Value, pubKeyDER, nonce, platformType)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("attestation extension not found in certificate")
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error {
|
||||
// Decode EAT token from certificate extension
|
||||
// Note: We don't have the public key for verification here, so we decode without verification
|
||||
// The signature was created by the attester, and we trust the TEE hardware verification
|
||||
claims, err := eat.DecodeCBOR(extension, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode EAT token: %w", err)
|
||||
}
|
||||
|
||||
// Verify nonce matches
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
// The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim
|
||||
// This matches the Attestation Service API and standard cryptographic nonce sizes.
|
||||
expectedNonce := hashNonce[:32]
|
||||
|
||||
// Compare nonces (EAT nonce should match our computed nonce)
|
||||
if len(claims.Nonce) != len(expectedNonce) {
|
||||
err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce))
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
nonceMatch := true
|
||||
for i := range claims.Nonce {
|
||||
if claims.Nonce[i] != expectedNonce[i] {
|
||||
nonceMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !nonceMatch {
|
||||
err := fmt.Errorf("nonce mismatch in EAT token")
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// Get platform verifier
|
||||
verifier, err := v.verifierProvider(platformType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get platform verifier: %w", err)
|
||||
}
|
||||
|
||||
// Load and parse CoRIM
|
||||
if attestation.AttestationPolicyPath == "" {
|
||||
return fmt.Errorf("attestation policy path is not set")
|
||||
}
|
||||
|
||||
corimBytes, err := os.ReadFile(attestation.AttestationPolicyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read CoRIM file: %w", err)
|
||||
}
|
||||
|
||||
// Try extracting from COSE Sign1 first
|
||||
var unsignedCorim *corim.UnsignedCorim
|
||||
|
||||
var sc corim.SignedCorim
|
||||
if err := sc.FromCOSE(corimBytes); err == nil {
|
||||
// It's a COSE Sign1 message
|
||||
unsignedCorim = &sc.UnsignedCorim
|
||||
} else {
|
||||
// Try parsing as unsigned CoRIM directly
|
||||
var uc corim.UnsignedCorim
|
||||
if err := uc.FromCBOR(corimBytes); err != nil {
|
||||
return fmt.Errorf("failed to parse CoRIM (tried both signed and unsigned): %w", err)
|
||||
}
|
||||
unsignedCorim = &uc
|
||||
}
|
||||
|
||||
// Re-wrap in Corim struct expected by Verifiers
|
||||
// Since verifiers expect the struct from the removed internal package,
|
||||
// we need to update verifiers to accept veraison/corim types
|
||||
// For now, we pass the unsignedCorim directly
|
||||
if err = verifier.VerifyWithCoRIM(claims.RawReport, unsignedCorim); err != nil {
|
||||
return fmt.Errorf("failed to verify attestation with CoRIM: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("CoRIM verification passed for aTLS peer certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func platformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
switch {
|
||||
case oid.Equal(SNPvTPMOID):
|
||||
return attestation.SNPvTPM, nil
|
||||
case oid.Equal(AzureOID):
|
||||
return attestation.Azure, nil
|
||||
case oid.Equal(TDXOID):
|
||||
return attestation.TDX, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported OID: %v", oid)
|
||||
}
|
||||
}
|
||||
|
||||
func platformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
|
||||
var verifier attestation.Verifier
|
||||
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
verifier = vtpm.NewVerifier(nil)
|
||||
case attestation.Azure:
|
||||
verifier = azure.NewVerifier(nil)
|
||||
case attestation.TDX:
|
||||
verifier = tdx.NewVerifier()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
|
||||
return verifier, nil
|
||||
}
|
||||
@@ -1,338 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/veraison/corim/corim"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
type mockVerifier struct {
|
||||
verifyWithCoRIMFunc func(report []byte, manifest *corim.UnsignedCorim) error
|
||||
}
|
||||
|
||||
func (m *mockVerifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
if m.verifyWithCoRIMFunc != nil {
|
||||
return m.verifyWithCoRIMFunc(report, manifest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Success(t *testing.T) {
|
||||
// Setup keys and cert templates
|
||||
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
caCert, err := x509.ParseCertificate(caCertDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
// Create verifier with mock platform verifier
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
|
||||
return &mockVerifier{
|
||||
verifyWithCoRIMFunc: func(report []byte, manifest *corim.UnsignedCorim) error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
peerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare EAT Claims
|
||||
nonce := []byte("test-nonce")
|
||||
peerPubKeyDER, err := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{
|
||||
Nonce: hashNonce[:32],
|
||||
RawReport: []byte("mock-report"),
|
||||
}
|
||||
eatBytes, err := cbor.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create Peer Cert with EAT extension
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{CommonName: "Test Peer"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: SNPvTPMOID, // Use SNPvTPMOID as default testing OID
|
||||
Value: eatBytes,
|
||||
},
|
||||
},
|
||||
}
|
||||
peerCertDER, err := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create dummy CoRIM file
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
c := corim.NewUnsignedCorim()
|
||||
c.SetID("cocos-test-id")
|
||||
corimBytes, err := c.ToCBOR()
|
||||
require.NoError(t, err)
|
||||
|
||||
policyPath := filepath.Join(tempDir, "attestation_policy.json")
|
||||
err = os.WriteFile(policyPath, corimBytes, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldPolicyPath := attestation.AttestationPolicyPath
|
||||
attestation.AttestationPolicyPath = policyPath
|
||||
t.Cleanup(func() {
|
||||
attestation.AttestationPolicyPath = oldPolicyPath
|
||||
})
|
||||
|
||||
err = verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) {
|
||||
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
caCert, _ := x509.ParseCertificate(caCertDER)
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
|
||||
return &mockVerifier{}, nil
|
||||
}
|
||||
|
||||
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
nonce := []byte("test-nonce")
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{CommonName: "Azure Peer"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
ExtraExtensions: []pkix.Extension{{Id: AzureOID, Value: eatBytes}},
|
||||
}
|
||||
peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
c := corim.NewUnsignedCorim()
|
||||
c.SetID("cocos-test-id")
|
||||
corimBytes, _ := c.ToCBOR()
|
||||
policyPath := filepath.Join(tempDir, "policy.cbor")
|
||||
_ = os.WriteFile(policyPath, corimBytes, 0o644)
|
||||
|
||||
oldPolicyPath := attestation.AttestationPolicyPath
|
||||
attestation.AttestationPolicyPath = policyPath
|
||||
t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath })
|
||||
|
||||
err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) {
|
||||
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
caCert, _ := x509.ParseCertificate(caCertDER)
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
|
||||
return &mockVerifier{}, nil
|
||||
}
|
||||
|
||||
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
nonce := []byte("test-nonce")
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(3),
|
||||
Subject: pkix.Name{CommonName: "TDX Peer"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
ExtraExtensions: []pkix.Extension{{Id: TDXOID, Value: eatBytes}},
|
||||
}
|
||||
peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
c := corim.NewUnsignedCorim()
|
||||
c.SetID("cocos-test-id")
|
||||
corimBytes, _ := c.ToCBOR()
|
||||
policyPath := filepath.Join(tempDir, "policy.cbor")
|
||||
_ = os.WriteFile(policyPath, corimBytes, 0o644)
|
||||
|
||||
oldPolicyPath := attestation.AttestationPolicyPath
|
||||
attestation.AttestationPolicyPath = policyPath
|
||||
t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath })
|
||||
|
||||
err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Failures_More(t *testing.T) {
|
||||
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
caCert, _ := x509.ParseCertificate(caCertDER)
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
|
||||
// Case 1: Invalid OID
|
||||
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(4),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
ExtraExtensions: []pkix.Extension{{Id: []int{1, 2, 3}, Value: []byte("val")}},
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
err := verifier.VerifyPeerCertificate([][]byte{certDER}, nil, []byte("nonce"))
|
||||
assert.ErrorContains(t, err, "attestation extension not found")
|
||||
|
||||
// Case 2: Policy path not set
|
||||
attestation.AttestationPolicyPath = ""
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
nonce := []byte("nonce")
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
|
||||
certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
err = verifier.VerifyPeerCertificate([][]byte{certDERWithExt}, nil, nonce)
|
||||
assert.ErrorContains(t, err, "attestation policy path is not set")
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) {
|
||||
rootCAs := x509.NewCertPool()
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
|
||||
// Case 1: No certificates
|
||||
err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce"))
|
||||
assert.ErrorContains(t, err, "no certificates provided")
|
||||
|
||||
// Case 2: Nonce length mismatch
|
||||
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
nonce := []byte("nonce")
|
||||
claims := eat.EATClaims{Nonce: []byte("short"), RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "CA"},
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
}
|
||||
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
caCert, _ := x509.ParseCertificate(caCertDER)
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(5),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
ExtraExtensions: []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}},
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
|
||||
assert.ErrorContains(t, err, "nonce length mismatch")
|
||||
|
||||
// Case 3: Nonce mismatch
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...)
|
||||
wrongHashNonce := sha3.Sum512(wrongTeeNonce)
|
||||
claims.Nonce = wrongHashNonce[:32]
|
||||
eatBytes, _ = cbor.Marshal(claims)
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
|
||||
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
|
||||
assert.ErrorContains(t, err, "nonce mismatch in EAT token")
|
||||
|
||||
// Case 4: Invalid EAT (CBOR decoder failure)
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: []byte("invalid-cbor")}}
|
||||
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce)
|
||||
assert.ErrorContains(t, err, "failed to decode EAT token")
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTruncated = errors.New("ea: truncated input")
|
||||
ErrInvalidLength = errors.New("ea: invalid length")
|
||||
ErrUnsupportedHandshakeType = errors.New("ea: unsupported handshake type")
|
||||
ErrNotTLS13 = errors.New("ea: not TLS 1.3")
|
||||
ErrUnknownCipherSuite = errors.New("ea: unknown cipher suite")
|
||||
ErrContextReuse = errors.New("ea: certificate_request_context already used")
|
||||
ErrInvalidRole = errors.New("ea: invalid authenticator role")
|
||||
|
||||
ErrUnsupportedSignatureScheme = errors.New("ea: unsupported signature scheme")
|
||||
ErrSignatureMismatch = errors.New("ea: CertificateVerify signature mismatch")
|
||||
ErrFinishedMismatch = errors.New("ea: Finished MAC mismatch")
|
||||
ErrContextMismatch = errors.New("ea: certificate_request_context mismatch")
|
||||
ErrBadRequest = errors.New("ea: bad authenticator request")
|
||||
)
|
||||
|
||||
type ValidationResult struct {
|
||||
Context []byte
|
||||
Chain []*x509.Certificate
|
||||
CMWAttestation []byte
|
||||
Attestation *eaattestation.VerifiedPayload
|
||||
Empty bool
|
||||
}
|
||||
|
||||
func CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
|
||||
return createAuthenticator(nil, st, role, req, nil, identity, leafEntryExtensions)
|
||||
}
|
||||
|
||||
func CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
|
||||
return createAuthenticator(nil, st, role, req, policy, identity, leafEntryExtensions)
|
||||
}
|
||||
|
||||
func createAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
|
||||
if st.Version != tls.VersionTLS13 {
|
||||
return nil, ErrNotTLS13
|
||||
}
|
||||
if req == nil && role != RoleServer {
|
||||
return nil, ErrInvalidRole
|
||||
}
|
||||
if err := validateCertificateEntryExtensions(leafEntryExtensions, req, policy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
emptyAuthenticator := len(identity.Certificate) == 0 && identity.PrivateKey == nil
|
||||
if !emptyAuthenticator && (len(identity.Certificate) == 0 || identity.PrivateKey == nil) {
|
||||
return nil, ErrBadRequest
|
||||
}
|
||||
|
||||
var reqBytes []byte
|
||||
var offered []uint16
|
||||
var ctx []byte
|
||||
|
||||
if req != nil {
|
||||
var err error
|
||||
reqBytes, err = req.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = append([]byte(nil), req.Context...)
|
||||
if schemes, ok := req.SignatureSchemes(); ok {
|
||||
offered = schemes
|
||||
} else {
|
||||
return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest)
|
||||
}
|
||||
} else {
|
||||
c, err := NewRandomContext(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = c
|
||||
}
|
||||
|
||||
hsCtx, h, err := ExportHandshakeContext(st, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fk, _, err := ExportFinishedKey(st, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if emptyAuthenticator {
|
||||
if req == nil {
|
||||
return nil, ErrBadRequest
|
||||
}
|
||||
certBytes, err := (CertificateMessage{Context: ctx}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
th := hashConcat(h, hsCtx, reqBytes, certBytes)
|
||||
verifyData := hmacSum(h, fk, th)
|
||||
finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := session.MarkContextUsed(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return finBytes, nil
|
||||
}
|
||||
|
||||
scheme, err := chooseSignatureScheme(identity.PrivateKey, offered)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req == nil && !policyPermitsSignatureScheme(policy, scheme) {
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
|
||||
entries := make([]CertificateEntry, 0, len(identity.Certificate))
|
||||
for i, der := range identity.Certificate {
|
||||
exts := []Extension(nil)
|
||||
if i == 0 && len(leafEntryExtensions) > 0 {
|
||||
exts = leafEntryExtensions
|
||||
}
|
||||
entries = append(entries, CertificateEntry{CertDER: der, Extensions: exts})
|
||||
}
|
||||
certBytes, err := (CertificateMessage{Context: ctx, Entries: entries}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
|
||||
sig, err := signCertVerify(identity.PrivateKey, scheme, th1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cvBytes, err := (CertificateVerifyMessage{Algorithm: scheme, Signature: sig}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
|
||||
verifyData := hmacSum(h, fk, th2)
|
||||
finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := session.MarkContextUsed(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(append(certBytes, cvBytes...), finBytes...), nil
|
||||
}
|
||||
|
||||
func ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
|
||||
return validateAuthenticator(nil, st, role, req, nil, nil, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
|
||||
return validateAuthenticator(nil, st, role, req, policy, nil, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
|
||||
return validateAuthenticator(nil, st, role, req, nil, &attPolicy, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
|
||||
return validateAuthenticator(nil, st, role, req, policy, &attPolicy, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func validateAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, attPolicy *eaattestation.VerificationPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
|
||||
if st.Version != tls.VersionTLS13 {
|
||||
return nil, ErrNotTLS13
|
||||
}
|
||||
|
||||
hsCtx, h, err := ExportHandshakeContext(st, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fk, _, err := ExportFinishedKey(st, role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reqBytes []byte
|
||||
var offered []uint16
|
||||
var reqCtx []byte
|
||||
if req != nil {
|
||||
reqCtx = req.Context
|
||||
reqBytes, err = req.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if schemes, ok := req.SignatureSchemes(); ok {
|
||||
offered = schemes
|
||||
} else {
|
||||
return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
firstHm, rest, err := UnmarshalHandshakeMessage(authBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if firstHm.Type == HandshakeTypeFinished {
|
||||
if req == nil || len(rest) != 0 {
|
||||
return nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
finBytes, _ := MarshalHandshakeMessage(firstHm)
|
||||
finMsg, _, err := UnmarshalFinishedMessage(finBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certBytes, err := (CertificateMessage{Context: reqCtx}).Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
th := hashConcat(h, hsCtx, reqBytes, certBytes)
|
||||
expectedFin := hmacSum(h, fk, th)
|
||||
if !constantTimeEqual(expectedFin, finMsg.VerifyData) {
|
||||
return nil, ErrFinishedMismatch
|
||||
}
|
||||
if err := session.MarkContextUsed(reqCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ValidationResult{
|
||||
Context: append([]byte(nil), reqCtx...),
|
||||
Empty: true,
|
||||
}, nil
|
||||
}
|
||||
if firstHm.Type != HandshakeTypeCertificate {
|
||||
return nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
certHm := firstHm
|
||||
certBytes, _ := MarshalHandshakeMessage(certHm)
|
||||
|
||||
cvHm, rest, err := UnmarshalHandshakeMessage(rest)
|
||||
if err != nil || cvHm.Type != HandshakeTypeCertificateVerify {
|
||||
return nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
cvBytes, _ := MarshalHandshakeMessage(cvHm)
|
||||
|
||||
finHm, rest, err := UnmarshalHandshakeMessage(rest)
|
||||
if err != nil || finHm.Type != HandshakeTypeFinished || len(rest) != 0 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
finBytes, _ := MarshalHandshakeMessage(finHm)
|
||||
|
||||
certMsg, _, err := UnmarshalCertificateMessage(certBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cvMsg, _, err := UnmarshalCertificateVerifyMessage(cvBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finMsg, _, err := UnmarshalFinishedMessage(finBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req != nil && !bytes.Equal(certMsg.Context, reqCtx) {
|
||||
return nil, ErrContextMismatch
|
||||
}
|
||||
if len(certMsg.Entries) == 0 {
|
||||
// Empty authenticators are encoded as Finished-only. A Certificate
|
||||
// message with zero entries followed by CertificateVerify/Finished is
|
||||
// malformed and must not be accepted as an empty authenticator.
|
||||
return nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
if err := ValidateCMWAttestationPlacement(certMsg.Entries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, entry := range certMsg.Entries {
|
||||
if err := validateCertificateEntryExtensions(entry.Extensions, req, policy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
extracted, present, err := ExtractCMWAttestationFromExtensions(certMsg.Entries[0].Extensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if present && req != nil && !RequestPermitsCertificateExtension(req, CMWAttestationExtensionType) {
|
||||
return nil, ErrBadRequest
|
||||
}
|
||||
if present && req == nil && !PolicyPermitsCertificateExtension(policy, CMWAttestationExtensionType) {
|
||||
return nil, ErrBadRequest
|
||||
}
|
||||
|
||||
chain := make([]*x509.Certificate, 0, len(certMsg.Entries))
|
||||
for _, e := range certMsg.Entries {
|
||||
c, err := x509.ParseCertificate(e.CertDER)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chain = append(chain, c)
|
||||
}
|
||||
leaf := chain[0]
|
||||
|
||||
if req != nil {
|
||||
ok := false
|
||||
for _, s := range offered {
|
||||
if s == cvMsg.Algorithm {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
}
|
||||
|
||||
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
|
||||
if err := verifyCertVerify(leaf.PublicKey, cvMsg.Algorithm, th1, cvMsg.Signature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
|
||||
expectedFin := hmacSum(h, fk, th2)
|
||||
if !constantTimeEqual(expectedFin, finMsg.VerifyData) {
|
||||
return nil, ErrFinishedMismatch
|
||||
}
|
||||
|
||||
if verifyOpts != nil {
|
||||
opts := *verifyOpts
|
||||
if opts.Intermediates == nil {
|
||||
opts.Intermediates = x509.NewCertPool()
|
||||
}
|
||||
for _, ic := range chain[1:] {
|
||||
opts.Intermediates.AddCert(ic)
|
||||
}
|
||||
if _, err := leaf.Verify(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := session.MarkContextUsed(certMsg.Context); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &ValidationResult{
|
||||
Context: append([]byte(nil), certMsg.Context...),
|
||||
Chain: chain,
|
||||
}
|
||||
if present {
|
||||
res.CMWAttestation = extracted
|
||||
parsed, err := eaattestation.ParsePayload(extracted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var verifierPolicy eaattestation.VerificationPolicy
|
||||
// A nil attestation policy is intentional: VerifyPayload then fails closed
|
||||
// for any payload that carries evidence or attestation results without
|
||||
// explicit verifiers being configured.
|
||||
if attPolicy != nil {
|
||||
verifierPolicy = *attPolicy
|
||||
}
|
||||
verified, err := eaattestation.VerifyPayload(st, eaattestation.ExporterLabelAttestation, certMsg.Context, leaf, parsed, verifierPolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.Attestation = verified
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
@@ -0,0 +1,537 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
attestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
)
|
||||
|
||||
func selfSignedCert(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "ea-phase3"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}
|
||||
}
|
||||
|
||||
func tlsPair(t *testing.T, cert tls.Certificate) (srv, cli *tls.Conn) {
|
||||
t.Helper()
|
||||
srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
|
||||
cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
|
||||
a, b := net.Pipe()
|
||||
srv = tls.Server(a, srvConf)
|
||||
cli = tls.Client(b, cliConf)
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- srv.Handshake() }()
|
||||
go func() { errCh <- cli.Handshake() }()
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("handshake: %v", err)
|
||||
}
|
||||
}
|
||||
return srv, cli
|
||||
}
|
||||
|
||||
type acceptEvidenceVerifier struct{}
|
||||
|
||||
func (acceptEvidenceVerifier) VerifyEvidence(evidence []byte) error { return nil }
|
||||
|
||||
func TestDummyAttestationRoundTrip(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(16)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
CMWAttestationOfferExtension(),
|
||||
},
|
||||
}
|
||||
|
||||
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
|
||||
srvState := srv.ConnectionState()
|
||||
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
|
||||
Version: 1,
|
||||
Evidence: []byte("dummy-attestation-report"),
|
||||
MediaType: "application/eat+cwt",
|
||||
Binder: attestation.AttestationBinder{
|
||||
ExporterLabel: attestation.ExporterLabelAttestation,
|
||||
AIKPubHash: aikPubHash,
|
||||
Binding: binding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ext, err := CMWAttestationDataExtension(payloadBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(leaf)
|
||||
|
||||
res, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}, attestation.VerificationPolicy{
|
||||
EvidenceVerifier: acceptEvidenceVerifier{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(res.CMWAttestation, payloadBytes) {
|
||||
t.Fatalf("cmw mismatch")
|
||||
}
|
||||
if res.Attestation == nil || !res.Attestation.BindingVerified || !res.Attestation.EvidenceVerified {
|
||||
t.Fatalf("expected verified attestation result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectIfNotOffered(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(8)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
},
|
||||
}
|
||||
|
||||
ext, _ := CMWAttestationDataExtension([]byte("dummy"))
|
||||
srvState := srv.ConnectionState()
|
||||
if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil {
|
||||
t.Fatalf("expected error when cmw_attestation not offered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationFailsClosedWithoutVerifier(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(16)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
CMWAttestationOfferExtension(),
|
||||
},
|
||||
}
|
||||
|
||||
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
|
||||
srvState := srv.ConnectionState()
|
||||
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
|
||||
Version: 1,
|
||||
Evidence: []byte("dummy-attestation-report"),
|
||||
Binder: attestation.AttestationBinder{
|
||||
ExporterLabel: attestation.ExporterLabelAttestation,
|
||||
AIKPubHash: aikPubHash,
|
||||
Binding: binding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ext, err := CMWAttestationDataExtension(payloadBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != attestation.ErrEvidenceVerificationMissing {
|
||||
t.Fatalf("got %v, want %v", err, attestation.ErrEvidenceVerificationMissing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectCMWAttestationOnIntermediateEntry(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(16)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
CMWAttestationOfferExtension(),
|
||||
},
|
||||
}
|
||||
|
||||
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
|
||||
srvState := srv.ConnectionState()
|
||||
hsCtx, h, err := ExportHandshakeContext(&srvState, RoleServer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fk, _, err := ExportFinishedKey(&srvState, RoleServer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reqBytes, err := req.Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
|
||||
Version: 1,
|
||||
Evidence: []byte("dummy-attestation-report"),
|
||||
MediaType: "application/eat+cwt",
|
||||
Binder: attestation.AttestationBinder{
|
||||
ExporterLabel: attestation.ExporterLabelAttestation,
|
||||
AIKPubHash: aikPubHash,
|
||||
Binding: binding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ext, err := CMWAttestationDataExtension(payloadBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certBytes, err := (CertificateMessage{
|
||||
Context: ctx,
|
||||
Entries: []CertificateEntry{
|
||||
{CertDER: cert.Certificate[0]},
|
||||
{CertDER: cert.Certificate[0], Extensions: []Extension{ext}},
|
||||
},
|
||||
}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th1 := hashConcat(h, hsCtx, reqBytes, certBytes)
|
||||
sig, err := signCertVerify(cert.PrivateKey, uint16(tls.ECDSAWithP256AndSHA256), th1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cvBytes, err := (CertificateVerifyMessage{
|
||||
Algorithm: uint16(tls.ECDSAWithP256AndSHA256),
|
||||
Signature: sig,
|
||||
}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes)
|
||||
finBytes, err := (FinishedMessage{VerifyData: hmacSum(h, fk, th2)}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
auth := append(append(certBytes, cvBytes...), finBytes...)
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
if _, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, nil, attestation.VerificationPolicy{
|
||||
EvidenceVerifier: acceptEvidenceVerifier{},
|
||||
}); err != ErrBadRequest {
|
||||
t.Fatalf("got %v, want %v", err, ErrBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRejectsContextReuse(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(12)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
},
|
||||
}
|
||||
|
||||
createSession := NewSession()
|
||||
srvState := srv.ConnectionState()
|
||||
auth, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil); err != ErrContextReuse {
|
||||
t.Fatalf("got %v, want %v", err, ErrContextReuse)
|
||||
}
|
||||
|
||||
validateSession := NewSession()
|
||||
cliState := cli.ConnectionState()
|
||||
roots := x509.NewCertPool()
|
||||
leaf, _ := x509.ParseCertificate(cert.Certificate[0])
|
||||
roots.AddCert(leaf)
|
||||
|
||||
if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != ErrContextReuse {
|
||||
t.Fatalf("got %v, want %v", err, ErrContextReuse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyAuthenticatorRoundTrip(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(10)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
},
|
||||
}
|
||||
|
||||
srvState := srv.ConnectionState()
|
||||
auth, err := CreateAuthenticator(&srvState, RoleServer, req, tls.Certificate{}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
res, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !res.Empty {
|
||||
t.Fatalf("expected empty authenticator result")
|
||||
}
|
||||
if len(res.Chain) != 0 {
|
||||
t.Fatalf("expected no certificate chain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectCertificateMessageWithEmptyEntries(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(10)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
},
|
||||
}
|
||||
|
||||
certBytes, err := (CertificateMessage{Context: ctx}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cvBytes, err := (CertificateVerifyMessage{
|
||||
Algorithm: uint16(tls.ECDSAWithP256AndSHA256),
|
||||
Signature: []byte{0x01},
|
||||
}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
finBytes, err := (FinishedMessage{VerifyData: []byte{0x01}}).Marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
auth := append(append(certBytes, cvBytes...), finBytes...)
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != ErrUnsupportedHandshakeType {
|
||||
t.Fatalf("got %v, want %v", err, ErrUnsupportedHandshakeType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectSpontaneousClientAuthenticator(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
cliState := cli.ConnectionState()
|
||||
if _, err := CreateAuthenticator(&cliState, RoleClient, nil, cert, nil); err != ErrInvalidRole {
|
||||
t.Fatalf("got %v, want %v", err, ErrInvalidRole)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestParsers(t *testing.T) {
|
||||
oidDER, err := asn1.Marshal(asn1.ObjectIdentifier{2, 5, 4, 3})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oidFilterPayload := append([]byte{byte(len(oidDER))}, oidDER...)
|
||||
oidFilterPayload = append(oidFilterPayload, 0x00, 0x02, 'o', 'k')
|
||||
req := AuthenticatorRequest{
|
||||
Type: HandshakeTypeCertificateRequest,
|
||||
Context: []byte{1, 2, 3},
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x04, 0x04, 0x03, 0x08, 0x07}},
|
||||
{Type: SignatureAlgorithmsCertExtensionType, Data: []byte{0x00, 0x02, 0x08, 0x07}},
|
||||
{Type: CertificateAuthoritiesExtensionType, Data: []byte{0x00, 0x06, 0x00, 0x04, 't', 'e', 's', 't'}},
|
||||
{Type: OIDFiltersExtensionType, Data: append([]byte{0x00, byte(len(oidFilterPayload))}, oidFilterPayload...)},
|
||||
},
|
||||
}
|
||||
|
||||
if got, ok := req.SignatureSchemes(); !ok || len(got) != 2 || got[0] != uint16(tls.ECDSAWithP256AndSHA256) || got[1] != uint16(tls.Ed25519) {
|
||||
t.Fatalf("unexpected signature schemes: %v %v", got, ok)
|
||||
}
|
||||
if got, ok := req.SignatureSchemesCert(); !ok || len(got) != 1 || got[0] != uint16(tls.Ed25519) {
|
||||
t.Fatalf("unexpected signature_algorithms_cert: %v %v", got, ok)
|
||||
}
|
||||
if got, ok := req.CertificateAuthorities(); !ok || len(got) != 1 || string(got[0]) != "test" {
|
||||
t.Fatalf("unexpected certificate authorities: %q %v", got, ok)
|
||||
}
|
||||
if got, ok := req.OIDFilters(); !ok || len(got) != 1 || !got[0].OID.Equal(asn1.ObjectIdentifier{2, 5, 4, 3}) || string(got[0].Values) != "ok" {
|
||||
t.Fatalf("unexpected oid filters: %#v %v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectLeafExtensionNotPermittedByRequest(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, _ := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, _ := NewRandomContext(8)
|
||||
req := &AuthenticatorRequest{
|
||||
Type: HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []Extension{
|
||||
{Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}},
|
||||
},
|
||||
}
|
||||
|
||||
srvState := srv.ConnectionState()
|
||||
ext := Extension{Type: 0x1234, Data: []byte{0x00}}
|
||||
if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil {
|
||||
t.Fatalf("expected policy error for unpermitted leaf extension")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpontaneousPolicyPermitsCertificateExtension(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, cli := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
payloadBytes, err := attestation.MarshalPayload(attestation.Payload{
|
||||
Version: 1,
|
||||
Evidence: []byte("dummy-attestation-report"),
|
||||
MediaType: "application/eat+cwt",
|
||||
Binder: attestation.AttestationBinder{
|
||||
ExporterLabel: attestation.ExporterLabelAttestation,
|
||||
AIKPubHash: []byte("placeholder-aik"),
|
||||
Binding: []byte("placeholder-binding"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ext, err := CMWAttestationDataExtension(payloadBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
policy := &SpontaneousAuthenticatorPolicy{
|
||||
AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)},
|
||||
AllowedCertificateExtensions: []uint16{CMWAttestationExtensionType},
|
||||
}
|
||||
|
||||
srvState := srv.ConnectionState()
|
||||
auth, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(auth) == 0 {
|
||||
t.Fatalf("expected authenticator bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpontaneousPolicyRejectsCertificateExtension(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, _ := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
|
||||
ext, err := CMWAttestationDataExtension([]byte("dummy-attestation-report"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
policy := &SpontaneousAuthenticatorPolicy{
|
||||
AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)},
|
||||
}
|
||||
|
||||
srvState := srv.ConnectionState()
|
||||
if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext}); err == nil {
|
||||
t.Fatalf("expected policy error for unpermitted spontaneous extension")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpontaneousPolicyRejectsSignatureScheme(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
srv, _ := tlsPair(t, cert)
|
||||
defer srv.Close()
|
||||
|
||||
policy := &SpontaneousAuthenticatorPolicy{
|
||||
AllowedSignatureSchemes: []uint16{uint16(tls.Ed25519)},
|
||||
}
|
||||
|
||||
srvState := srv.ConnectionState()
|
||||
if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, nil); err != ErrUnsupportedSignatureScheme {
|
||||
t.Fatalf("got %v, want %v", err, ErrUnsupportedSignatureScheme)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
type CertificateMessage struct {
|
||||
Context []byte
|
||||
Entries []CertificateEntry
|
||||
}
|
||||
|
||||
type CertificateEntry struct {
|
||||
CertDER []byte
|
||||
Extensions []Extension
|
||||
}
|
||||
|
||||
func (m CertificateMessage) Marshal() ([]byte, error) {
|
||||
if len(m.Context) > 255 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
|
||||
var listPayload []byte
|
||||
for _, e := range m.Entries {
|
||||
if len(e.CertDER) == 0 || len(e.CertDER) > 0xFFFFFF {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
extVec, err := MarshalExtensions(e.Extensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := make([]byte, 3+len(e.CertDER)+len(extVec))
|
||||
putUint24(entry[0:3], uint32(len(e.CertDER)))
|
||||
copy(entry[3:], e.CertDER)
|
||||
copy(entry[3+len(e.CertDER):], extVec)
|
||||
listPayload = append(listPayload, entry...)
|
||||
}
|
||||
|
||||
body := make([]byte, 1+len(m.Context)+3+len(listPayload))
|
||||
body[0] = byte(len(m.Context))
|
||||
copy(body[1:], m.Context)
|
||||
putUint24(body[1+len(m.Context):1+len(m.Context)+3], uint32(len(listPayload)))
|
||||
copy(body[1+len(m.Context)+3:], listPayload)
|
||||
|
||||
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificate, Body: body})
|
||||
}
|
||||
|
||||
func UnmarshalCertificateMessage(handshakeBytes []byte) (CertificateMessage, []byte, error) {
|
||||
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
|
||||
if err != nil {
|
||||
return CertificateMessage{}, nil, err
|
||||
}
|
||||
if len(rest) != 0 || hm.Type != HandshakeTypeCertificate {
|
||||
return CertificateMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
if len(hm.Body) < 1 {
|
||||
return CertificateMessage{}, nil, ErrTruncated
|
||||
}
|
||||
|
||||
ctxLen := int(hm.Body[0])
|
||||
if len(hm.Body) < 1+ctxLen+3 {
|
||||
return CertificateMessage{}, nil, ErrTruncated
|
||||
}
|
||||
ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...)
|
||||
|
||||
listLen := int(readUint24(hm.Body[1+ctxLen : 1+ctxLen+3]))
|
||||
if len(hm.Body) != 1+ctxLen+3+listLen {
|
||||
return CertificateMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
list := hm.Body[1+ctxLen+3:]
|
||||
|
||||
var entries []CertificateEntry
|
||||
for i := 0; i < len(list); {
|
||||
if len(list)-i < 3 {
|
||||
return CertificateMessage{}, nil, ErrTruncated
|
||||
}
|
||||
certLen := int(readUint24(list[i : i+3]))
|
||||
i += 3
|
||||
if certLen <= 0 || certLen > len(list)-i {
|
||||
return CertificateMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
certDER := append([]byte(nil), list[i:i+certLen]...)
|
||||
i += certLen
|
||||
|
||||
exts, leftover, err := UnmarshalExtensions(list[i:])
|
||||
if err != nil {
|
||||
return CertificateMessage{}, nil, err
|
||||
}
|
||||
consumed := len(list[i:]) - len(leftover)
|
||||
i += consumed
|
||||
|
||||
entries = append(entries, CertificateEntry{CertDER: certDER, Extensions: exts})
|
||||
}
|
||||
|
||||
raw, _ := MarshalHandshakeMessage(hm)
|
||||
return CertificateMessage{Context: ctx, Entries: entries}, raw, nil
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
type CertificateVerifyMessage struct {
|
||||
Algorithm uint16
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
func (m CertificateVerifyMessage) Marshal() ([]byte, error) {
|
||||
if len(m.Signature) > 0xFFFF {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
body := make([]byte, 4+len(m.Signature))
|
||||
putUint16(body[0:2], m.Algorithm)
|
||||
putUint16(body[2:4], uint16(len(m.Signature)))
|
||||
copy(body[4:], m.Signature)
|
||||
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificateVerify, Body: body})
|
||||
}
|
||||
|
||||
func UnmarshalCertificateVerifyMessage(handshakeBytes []byte) (CertificateVerifyMessage, []byte, error) {
|
||||
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
|
||||
if err != nil {
|
||||
return CertificateVerifyMessage{}, nil, err
|
||||
}
|
||||
if len(rest) != 0 || hm.Type != HandshakeTypeCertificateVerify {
|
||||
return CertificateVerifyMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
if len(hm.Body) < 4 {
|
||||
return CertificateVerifyMessage{}, nil, ErrTruncated
|
||||
}
|
||||
alg := readUint16(hm.Body[0:2])
|
||||
sigLen := int(readUint16(hm.Body[2:4]))
|
||||
if len(hm.Body) != 4+sigLen {
|
||||
return CertificateVerifyMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
sig := append([]byte(nil), hm.Body[4:]...)
|
||||
raw, _ := MarshalHandshakeMessage(hm)
|
||||
return CertificateVerifyMessage{Algorithm: alg, Signature: sig}, raw, nil
|
||||
}
|
||||
|
||||
var eaContextString = []byte("Exported Authenticator")
|
||||
|
||||
func buildCertVerifyInput(transcriptHash []byte) []byte {
|
||||
prefix := bytes.Repeat([]byte{0x20}, 64)
|
||||
out := make([]byte, 0, len(prefix)+len(eaContextString)+1+len(transcriptHash))
|
||||
out = append(out, prefix...)
|
||||
out = append(out, eaContextString...)
|
||||
out = append(out, 0x00)
|
||||
out = append(out, transcriptHash...)
|
||||
return out
|
||||
}
|
||||
|
||||
func signCertVerify(priv any, scheme uint16, transcriptHash []byte) ([]byte, error) {
|
||||
info, err := signatureSchemeInfo(scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := buildCertVerifyInput(transcriptHash)
|
||||
|
||||
switch info.Alg {
|
||||
case sigAlgECDSA:
|
||||
k, ok := priv.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
h := info.Hash.New()
|
||||
h.Write(msg)
|
||||
return ecdsa.SignASN1(rand.Reader, k, h.Sum(nil))
|
||||
|
||||
case sigAlgRSAPSS:
|
||||
k, ok := priv.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
h := info.Hash.New()
|
||||
h.Write(msg)
|
||||
return rsa.SignPSS(rand.Reader, k, info.Hash, h.Sum(nil),
|
||||
&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash})
|
||||
|
||||
case sigAlgEd25519:
|
||||
k, ok := priv.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
return ed25519.Sign(k, msg), nil
|
||||
|
||||
default:
|
||||
return nil, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCertVerify(pub any, scheme uint16, transcriptHash []byte, signature []byte) error {
|
||||
info, err := signatureSchemeInfo(scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := buildCertVerifyInput(transcriptHash)
|
||||
|
||||
switch info.Alg {
|
||||
case sigAlgECDSA:
|
||||
k, ok := pub.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrUnsupportedSignatureScheme
|
||||
}
|
||||
h := info.Hash.New()
|
||||
h.Write(msg)
|
||||
if !ecdsa.VerifyASN1(k, h.Sum(nil), signature) {
|
||||
return ErrSignatureMismatch
|
||||
}
|
||||
return nil
|
||||
|
||||
case sigAlgRSAPSS:
|
||||
k, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrUnsupportedSignatureScheme
|
||||
}
|
||||
h := info.Hash.New()
|
||||
h.Write(msg)
|
||||
if err := rsa.VerifyPSS(k, info.Hash, h.Sum(nil), signature,
|
||||
&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash}); err != nil {
|
||||
return ErrSignatureMismatch
|
||||
}
|
||||
return nil
|
||||
|
||||
case sigAlgEd25519:
|
||||
k, ok := pub.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return ErrUnsupportedSignatureScheme
|
||||
}
|
||||
if !ed25519.Verify(k, msg, signature) {
|
||||
return ErrSignatureMismatch
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
return ErrUnsupportedSignatureScheme
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
const CMWAttestationExtensionType uint16 = 0xFF00
|
||||
|
||||
func CMWAttestationOfferExtension() Extension {
|
||||
return Extension{Type: CMWAttestationExtensionType, Data: nil}
|
||||
}
|
||||
|
||||
func CMWAttestationDataExtension(cmw []byte) (Extension, error) {
|
||||
if len(cmw) == 0 || len(cmw) > 0xFFFF {
|
||||
return Extension{}, ErrInvalidLength
|
||||
}
|
||||
data := make([]byte, 2+len(cmw))
|
||||
putUint16(data[0:2], uint16(len(cmw)))
|
||||
copy(data[2:], cmw)
|
||||
return Extension{Type: CMWAttestationExtensionType, Data: data}, nil
|
||||
}
|
||||
|
||||
func ExtractCMWAttestationFromExtensions(exts []Extension) ([]byte, bool, error) {
|
||||
for _, e := range exts {
|
||||
if e.Type != CMWAttestationExtensionType {
|
||||
continue
|
||||
}
|
||||
if len(e.Data) < 2 {
|
||||
return nil, true, ErrInvalidLength
|
||||
}
|
||||
l := int(readUint16(e.Data[0:2]))
|
||||
if l <= 0 || l != len(e.Data)-2 {
|
||||
return nil, true, ErrInvalidLength
|
||||
}
|
||||
return append([]byte(nil), e.Data[2:]...), true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func ValidateCMWAttestationPlacement(entries []CertificateEntry) error {
|
||||
for i, entry := range entries {
|
||||
for _, ext := range entry.Extensions {
|
||||
if ext.Type != CMWAttestationExtensionType {
|
||||
continue
|
||||
}
|
||||
if i != 0 {
|
||||
return ErrBadRequest
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
LabelClientAuthenticatorHandshakeContext = "EXPORTER-client authenticator handshake context"
|
||||
LabelServerAuthenticatorHandshakeContext = "EXPORTER-server authenticator handshake context"
|
||||
LabelClientAuthenticatorFinishedKey = "EXPORTER-client authenticator finished key"
|
||||
LabelServerAuthenticatorFinishedKey = "EXPORTER-server authenticator finished key"
|
||||
)
|
||||
|
||||
type Role uint8
|
||||
|
||||
const (
|
||||
RoleClient Role = iota + 1
|
||||
RoleServer
|
||||
)
|
||||
|
||||
func AuthenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) {
|
||||
switch cipherSuite {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return crypto.SHA256, nil
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
return crypto.SHA384, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s (0x%04x)", ErrUnknownCipherSuite, tls.CipherSuiteName(cipherSuite), cipherSuite)
|
||||
}
|
||||
}
|
||||
|
||||
func ExportHandshakeContext(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) {
|
||||
if st.Version != tls.VersionTLS13 {
|
||||
return nil, 0, ErrNotTLS13
|
||||
}
|
||||
h, err := AuthenticatorHashTLS13(st.CipherSuite)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
label := LabelClientAuthenticatorHandshakeContext
|
||||
if role == RoleServer {
|
||||
label = LabelServerAuthenticatorHandshakeContext
|
||||
}
|
||||
out, err := st.ExportKeyingMaterial(label, nil, h.Size())
|
||||
return out, h, err
|
||||
}
|
||||
|
||||
func ExportFinishedKey(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) {
|
||||
if st.Version != tls.VersionTLS13 {
|
||||
return nil, 0, ErrNotTLS13
|
||||
}
|
||||
h, err := AuthenticatorHashTLS13(st.CipherSuite)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
label := LabelClientAuthenticatorFinishedKey
|
||||
if role == RoleServer {
|
||||
label = LabelServerAuthenticatorFinishedKey
|
||||
}
|
||||
out, err := st.ExportKeyingMaterial(label, nil, h.Size())
|
||||
return out, h, err
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
type Extension struct {
|
||||
Type uint16
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func MarshalExtensions(exts []Extension) ([]byte, error) {
|
||||
payloadLen := 0
|
||||
for _, e := range exts {
|
||||
if len(e.Data) > 0xFFFF {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
payloadLen += 4 + len(e.Data)
|
||||
if payloadLen > 0xFFFF {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
}
|
||||
out := make([]byte, 2+payloadLen)
|
||||
putUint16(out[0:2], uint16(payloadLen))
|
||||
off := 2
|
||||
for _, e := range exts {
|
||||
putUint16(out[off:off+2], e.Type)
|
||||
putUint16(out[off+2:off+4], uint16(len(e.Data)))
|
||||
copy(out[off+4:], e.Data)
|
||||
off += 4 + len(e.Data)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func UnmarshalExtensions(b []byte) (exts []Extension, rest []byte, err error) {
|
||||
if len(b) < 2 {
|
||||
return nil, nil, ErrTruncated
|
||||
}
|
||||
total := int(readUint16(b[0:2]))
|
||||
if len(b) < 2+total {
|
||||
return nil, nil, ErrTruncated
|
||||
}
|
||||
payload := b[2 : 2+total]
|
||||
rest = b[2+total:]
|
||||
i := 0
|
||||
for i < len(payload) {
|
||||
if len(payload)-i < 4 {
|
||||
return nil, nil, ErrTruncated
|
||||
}
|
||||
typ := readUint16(payload[i : i+2])
|
||||
l := int(readUint16(payload[i+2 : i+4]))
|
||||
i += 4
|
||||
if l < 0 || l > len(payload)-i {
|
||||
return nil, nil, ErrInvalidLength
|
||||
}
|
||||
data := make([]byte, l)
|
||||
copy(data, payload[i:i+l])
|
||||
i += l
|
||||
exts = append(exts, Extension{Type: typ, Data: data})
|
||||
}
|
||||
return exts, rest, nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
type FinishedMessage struct {
|
||||
VerifyData []byte
|
||||
}
|
||||
|
||||
func (m FinishedMessage) Marshal() ([]byte, error) {
|
||||
if len(m.VerifyData) == 0 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeFinished, Body: append([]byte(nil), m.VerifyData...)})
|
||||
}
|
||||
|
||||
func UnmarshalFinishedMessage(handshakeBytes []byte) (FinishedMessage, []byte, error) {
|
||||
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
|
||||
if err != nil {
|
||||
return FinishedMessage{}, nil, err
|
||||
}
|
||||
if len(rest) != 0 || hm.Type != HandshakeTypeFinished || len(hm.Body) == 0 {
|
||||
return FinishedMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
raw, _ := MarshalHandshakeMessage(hm)
|
||||
return FinishedMessage{VerifyData: append([]byte(nil), hm.Body...)}, raw, nil
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
const handshakeHeaderLen = 4 // 1 byte type + 3 byte uint24 len
|
||||
|
||||
const (
|
||||
HandshakeTypeCertificate uint8 = 11
|
||||
HandshakeTypeCertificateRequest uint8 = 13
|
||||
HandshakeTypeCertificateVerify uint8 = 15
|
||||
HandshakeTypeClientCertificateRequest uint8 = 17
|
||||
HandshakeTypeFinished uint8 = 20
|
||||
)
|
||||
|
||||
type HandshakeMessage struct {
|
||||
Type uint8
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func MarshalHandshakeMessage(m HandshakeMessage) ([]byte, error) {
|
||||
if len(m.Body) > 0xFFFFFF {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
out := make([]byte, handshakeHeaderLen+len(m.Body))
|
||||
out[0] = m.Type
|
||||
putUint24(out[1:4], uint32(len(m.Body)))
|
||||
copy(out[4:], m.Body)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func UnmarshalHandshakeMessage(b []byte) (msg HandshakeMessage, rest []byte, err error) {
|
||||
if len(b) < handshakeHeaderLen {
|
||||
return HandshakeMessage{}, nil, ErrTruncated
|
||||
}
|
||||
t := b[0]
|
||||
n := int(readUint24(b[1:4]))
|
||||
if n < 0 || n > 0xFFFFFF {
|
||||
return HandshakeMessage{}, nil, ErrInvalidLength
|
||||
}
|
||||
if len(b) < handshakeHeaderLen+n {
|
||||
return HandshakeMessage{}, nil, ErrTruncated
|
||||
}
|
||||
body := make([]byte, n)
|
||||
copy(body, b[4:4+n])
|
||||
return HandshakeMessage{Type: t, Body: body}, b[4+n:], nil
|
||||
}
|
||||
|
||||
func putUint24(dst []byte, v uint32) { dst[0] = byte(v >> 16); dst[1] = byte(v >> 8); dst[2] = byte(v) }
|
||||
func readUint24(src []byte) uint32 {
|
||||
return (uint32(src[0]) << 16) | (uint32(src[1]) << 8) | uint32(src[2])
|
||||
}
|
||||
|
||||
func putUint16(dst []byte, v uint16) { dst[0] = byte(v >> 8); dst[1] = byte(v) }
|
||||
func readUint16(src []byte) uint16 { return (uint16(src[0]) << 8) | uint16(src[1]) }
|
||||
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
type SpontaneousAuthenticatorPolicy struct {
|
||||
AllowedSignatureSchemes []uint16
|
||||
AllowedCertificateExtensions []uint16
|
||||
}
|
||||
|
||||
func RequestPermitsCertificateExtension(req *AuthenticatorRequest, typ uint16) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range req.Extensions {
|
||||
if e.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func PolicyPermitsCertificateExtension(policy *SpontaneousAuthenticatorPolicy, typ uint16) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
for _, allowed := range policy.AllowedCertificateExtensions {
|
||||
if allowed == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyPermitsSignatureScheme(policy *SpontaneousAuthenticatorPolicy, scheme uint16) bool {
|
||||
if policy == nil || len(policy.AllowedSignatureSchemes) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, allowed := range policy.AllowedSignatureSchemes {
|
||||
if allowed == scheme {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateCertificateEntryExtensions(exts []Extension, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy) error {
|
||||
for _, e := range exts {
|
||||
if req != nil {
|
||||
if !RequestPermitsCertificateExtension(req, e.Type) {
|
||||
return ErrBadRequest
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !PolicyPermitsCertificateExtension(policy, e.Type) {
|
||||
return ErrBadRequest
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/asn1"
|
||||
)
|
||||
|
||||
const (
|
||||
SignatureAlgorithmsExtensionType uint16 = 0x000d
|
||||
ServerNameExtensionType uint16 = 0x0000
|
||||
CertificateAuthoritiesExtensionType uint16 = 0x002f
|
||||
OIDFiltersExtensionType uint16 = 0x0030
|
||||
SignatureAlgorithmsCertExtensionType uint16 = 0x0032
|
||||
)
|
||||
|
||||
type AuthenticatorRequest struct {
|
||||
Type uint8
|
||||
Context []byte
|
||||
Extensions []Extension
|
||||
}
|
||||
|
||||
type OIDFilter struct {
|
||||
OID asn1.ObjectIdentifier
|
||||
Values []byte
|
||||
}
|
||||
|
||||
func NewRandomContext(n int) ([]byte, error) {
|
||||
if n <= 0 || n > 255 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (r AuthenticatorRequest) Marshal() ([]byte, error) {
|
||||
if r.Type != HandshakeTypeCertificateRequest && r.Type != HandshakeTypeClientCertificateRequest {
|
||||
return nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
if len(r.Context) > 255 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
extVec, err := MarshalExtensions(r.Extensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := make([]byte, 1+len(r.Context)+len(extVec))
|
||||
body[0] = byte(len(r.Context))
|
||||
copy(body[1:], r.Context)
|
||||
copy(body[1+len(r.Context):], extVec)
|
||||
return MarshalHandshakeMessage(HandshakeMessage{Type: r.Type, Body: body})
|
||||
}
|
||||
|
||||
func UnmarshalAuthenticatorRequest(handshakeBytes []byte) (AuthenticatorRequest, []byte, error) {
|
||||
hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes)
|
||||
if err != nil {
|
||||
return AuthenticatorRequest{}, nil, err
|
||||
}
|
||||
if hm.Type != HandshakeTypeCertificateRequest && hm.Type != HandshakeTypeClientCertificateRequest {
|
||||
return AuthenticatorRequest{}, nil, ErrUnsupportedHandshakeType
|
||||
}
|
||||
if len(hm.Body) < 1 {
|
||||
return AuthenticatorRequest{}, nil, ErrTruncated
|
||||
}
|
||||
ctxLen := int(hm.Body[0])
|
||||
if len(hm.Body) < 1+ctxLen {
|
||||
return AuthenticatorRequest{}, nil, ErrTruncated
|
||||
}
|
||||
ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...)
|
||||
exts, leftover, err := UnmarshalExtensions(hm.Body[1+ctxLen:])
|
||||
if err != nil {
|
||||
return AuthenticatorRequest{}, nil, err
|
||||
}
|
||||
if len(leftover) != 0 {
|
||||
return AuthenticatorRequest{}, nil, ErrInvalidLength
|
||||
}
|
||||
return AuthenticatorRequest{
|
||||
Type: hm.Type,
|
||||
Context: ctx,
|
||||
Extensions: exts,
|
||||
}, rest, nil
|
||||
}
|
||||
|
||||
func (r AuthenticatorRequest) SignatureSchemes() ([]uint16, bool) {
|
||||
return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsExtensionType)
|
||||
}
|
||||
|
||||
func (r AuthenticatorRequest) SignatureSchemesCert() ([]uint16, bool) {
|
||||
return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsCertExtensionType)
|
||||
}
|
||||
|
||||
func (r AuthenticatorRequest) CertificateAuthorities() ([][]byte, bool) {
|
||||
for _, e := range r.Extensions {
|
||||
if e.Type != CertificateAuthoritiesExtensionType {
|
||||
continue
|
||||
}
|
||||
if len(e.Data) < 2 {
|
||||
return nil, false
|
||||
}
|
||||
total := int(readUint16(e.Data[0:2]))
|
||||
if total < 3 || len(e.Data) != 2+total {
|
||||
return nil, false
|
||||
}
|
||||
var out [][]byte
|
||||
for off := 2; off < len(e.Data); {
|
||||
if len(e.Data)-off < 2 {
|
||||
return nil, false
|
||||
}
|
||||
l := int(readUint16(e.Data[off : off+2]))
|
||||
off += 2
|
||||
if l == 0 || l > len(e.Data)-off {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, append([]byte(nil), e.Data[off:off+l]...))
|
||||
off += l
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r AuthenticatorRequest) OIDFilters() ([]OIDFilter, bool) {
|
||||
for _, e := range r.Extensions {
|
||||
if e.Type != OIDFiltersExtensionType {
|
||||
continue
|
||||
}
|
||||
if len(e.Data) < 2 {
|
||||
return nil, false
|
||||
}
|
||||
total := int(readUint16(e.Data[0:2]))
|
||||
if len(e.Data) != 2+total {
|
||||
return nil, false
|
||||
}
|
||||
var out []OIDFilter
|
||||
for off := 2; off < len(e.Data); {
|
||||
if len(e.Data)-off < 1 {
|
||||
return nil, false
|
||||
}
|
||||
oidLen := int(e.Data[off])
|
||||
off++
|
||||
if oidLen == 0 || oidLen > len(e.Data)-off {
|
||||
return nil, false
|
||||
}
|
||||
rawOID := append([]byte(nil), e.Data[off:off+oidLen]...)
|
||||
off += oidLen
|
||||
var oid asn1.ObjectIdentifier
|
||||
if _, err := asn1.Unmarshal(rawOID, &oid); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if len(e.Data)-off < 2 {
|
||||
return nil, false
|
||||
}
|
||||
valLen := int(readUint16(e.Data[off : off+2]))
|
||||
off += 2
|
||||
if valLen > len(e.Data)-off {
|
||||
return nil, false
|
||||
}
|
||||
values := append([]byte(nil), e.Data[off:off+valLen]...)
|
||||
off += valLen
|
||||
out = append(out, OIDFilter{OID: oid, Values: values})
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func parseSignatureSchemesExtension(exts []Extension, typ uint16) ([]uint16, bool) {
|
||||
for _, e := range exts {
|
||||
if e.Type != typ {
|
||||
continue
|
||||
}
|
||||
if len(e.Data) < 2 {
|
||||
return nil, false
|
||||
}
|
||||
vecLen := int(readUint16(e.Data[0:2]))
|
||||
if vecLen < 2 || vecLen%2 != 0 || len(e.Data) != 2+vecLen {
|
||||
return nil, false
|
||||
}
|
||||
out := make([]uint16, 0, vecLen/2)
|
||||
for off := 2; off < len(e.Data); off += 2 {
|
||||
out = append(out, readUint16(e.Data[off:off+2]))
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func SignatureAlgorithmsExtension(schemes []uint16) (Extension, error) {
|
||||
return marshalSignatureSchemesExtension(SignatureAlgorithmsExtensionType, schemes)
|
||||
}
|
||||
|
||||
func SignatureAlgorithmsCertExtension(schemes []uint16) (Extension, error) {
|
||||
return marshalSignatureSchemesExtension(SignatureAlgorithmsCertExtensionType, schemes)
|
||||
}
|
||||
|
||||
func marshalSignatureSchemesExtension(typ uint16, schemes []uint16) (Extension, error) {
|
||||
if len(schemes) == 0 {
|
||||
return Extension{}, ErrInvalidLength
|
||||
}
|
||||
if len(schemes) > 0x7fff {
|
||||
return Extension{}, ErrInvalidLength
|
||||
}
|
||||
data := make([]byte, 2+2*len(schemes))
|
||||
putUint16(data[0:2], uint16(2*len(schemes)))
|
||||
off := 2
|
||||
for _, s := range schemes {
|
||||
putUint16(data[off:off+2], s)
|
||||
off += 2
|
||||
}
|
||||
return Extension{Type: typ, Data: data}, nil
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"sync"
|
||||
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
mu sync.Mutex
|
||||
used map[string]struct{}
|
||||
}
|
||||
|
||||
func NewSession() *Session {
|
||||
return &Session{used: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
func (s *Session) MarkContextUsed(ctx []byte) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
key := string(ctx)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.used[key]; ok {
|
||||
return ErrContextReuse
|
||||
}
|
||||
s.used[key] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
|
||||
return createAuthenticator(s, st, role, req, nil, identity, leafEntryExtensions)
|
||||
}
|
||||
|
||||
func (s *Session) ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
|
||||
return validateAuthenticator(s, st, role, req, nil, nil, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func (s *Session) CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) {
|
||||
return createAuthenticator(s, st, role, req, policy, identity, leafEntryExtensions)
|
||||
}
|
||||
|
||||
func (s *Session) ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) {
|
||||
return validateAuthenticator(s, st, role, req, policy, nil, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func (s *Session) ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
|
||||
return validateAuthenticator(s, st, role, req, nil, &attPolicy, authBytes, verifyOpts)
|
||||
}
|
||||
|
||||
func (s *Session) ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) {
|
||||
return validateAuthenticator(s, st, role, req, policy, &attPolicy, authBytes, verifyOpts)
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type sigAlg uint8
|
||||
|
||||
const (
|
||||
sigAlgECDSA sigAlg = iota + 1
|
||||
sigAlgRSAPSS
|
||||
sigAlgEd25519
|
||||
)
|
||||
|
||||
type sigSchemeInfo struct {
|
||||
Scheme uint16
|
||||
Alg sigAlg
|
||||
Hash crypto.Hash // 0 for Ed25519
|
||||
}
|
||||
|
||||
func signatureSchemeInfo(s uint16) (sigSchemeInfo, error) {
|
||||
switch s {
|
||||
case uint16(tls.ECDSAWithP256AndSHA256):
|
||||
return sigSchemeInfo{s, sigAlgECDSA, crypto.SHA256}, nil
|
||||
case uint16(tls.PSSWithSHA256):
|
||||
return sigSchemeInfo{s, sigAlgRSAPSS, crypto.SHA256}, nil
|
||||
case uint16(tls.Ed25519):
|
||||
return sigSchemeInfo{s, sigAlgEd25519, 0}, nil
|
||||
default:
|
||||
return sigSchemeInfo{}, fmt.Errorf("%w: 0x%04x", ErrUnsupportedSignatureScheme, s)
|
||||
}
|
||||
}
|
||||
|
||||
func chooseSignatureScheme(priv any, offered []uint16) (uint16, error) {
|
||||
compat := func(s uint16) bool {
|
||||
info, err := signatureSchemeInfo(s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch info.Alg {
|
||||
case sigAlgECDSA:
|
||||
k, ok := priv.(*ecdsa.PrivateKey)
|
||||
return ok && k.Curve.Params().Name == "P-256"
|
||||
case sigAlgRSAPSS:
|
||||
_, ok := priv.(*rsa.PrivateKey)
|
||||
return ok
|
||||
case sigAlgEd25519:
|
||||
_, ok := priv.(ed25519.PrivateKey)
|
||||
return ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if len(offered) > 0 {
|
||||
for _, s := range offered {
|
||||
if compat(s) {
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
return 0, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
|
||||
if k, ok := priv.(*ecdsa.PrivateKey); ok && k.Curve.Params().Name == "P-256" {
|
||||
return uint16(tls.ECDSAWithP256AndSHA256), nil
|
||||
}
|
||||
if _, ok := priv.(*rsa.PrivateKey); ok {
|
||||
return uint16(tls.PSSWithSHA256), nil
|
||||
}
|
||||
if _, ok := priv.(ed25519.PrivateKey); ok {
|
||||
return uint16(tls.Ed25519), nil
|
||||
}
|
||||
return 0, ErrUnsupportedSignatureScheme
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ea
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/subtle"
|
||||
)
|
||||
|
||||
func hashConcat(h crypto.Hash, chunks ...[]byte) []byte {
|
||||
hs := h.New()
|
||||
for _, c := range chunks {
|
||||
if len(c) == 0 {
|
||||
continue
|
||||
}
|
||||
hs.Write(c)
|
||||
}
|
||||
return hs.Sum(nil)
|
||||
}
|
||||
|
||||
func hmacSum(h crypto.Hash, key, data []byte) []byte {
|
||||
m := hmac.New(h.New, key)
|
||||
m.Write(data)
|
||||
return m.Sum(nil)
|
||||
}
|
||||
|
||||
func constantTimeEqual(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
ExporterLabelAttestation = "Attestation"
|
||||
ExporterLabelAttestationBinding = "Attestation Binding"
|
||||
)
|
||||
|
||||
const ExportedAttestationValueLen = 32
|
||||
|
||||
var errNotTLS13 = errors.New("attestation: not TLS 1.3")
|
||||
|
||||
func ExportAttestationValue(st *tls.ConnectionState, label string, contextValue []byte) ([]byte, crypto.Hash, error) {
|
||||
if st.Version != tls.VersionTLS13 {
|
||||
return nil, 0, errNotTLS13
|
||||
}
|
||||
h, err := authenticatorHashTLS13(st.CipherSuite)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
out, err := st.ExportKeyingMaterial(label, contextValue, ExportedAttestationValueLen)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return out, h, nil
|
||||
}
|
||||
|
||||
func PublicKeyBytes(leaf *x509.Certificate) ([]byte, error) {
|
||||
if leaf == nil {
|
||||
return nil, fmt.Errorf("nil leaf cert")
|
||||
}
|
||||
if len(leaf.RawSubjectPublicKeyInfo) > 0 {
|
||||
return leaf.RawSubjectPublicKeyInfo, nil
|
||||
}
|
||||
b, err := x509.MarshalPKIXPublicKey(leaf.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func AIKPublicKeyHash(h crypto.Hash, pubKey []byte) []byte {
|
||||
hs := h.New()
|
||||
hs.Write(pubKey)
|
||||
return hs.Sum(nil)
|
||||
}
|
||||
|
||||
func BindingValue(h crypto.Hash, pubKey, exportedValue []byte) []byte {
|
||||
hs := h.New()
|
||||
hs.Write(pubKey)
|
||||
hs.Write(exportedValue)
|
||||
return hs.Sum(nil)
|
||||
}
|
||||
|
||||
func ComputeBinding(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate) (exportedValue, aikPubHash, binding []byte, err error) {
|
||||
exportedValue, h, err := ExportAttestationValue(st, label, certificateRequestContext)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
pub, err := PublicKeyBytes(leaf)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
aikPubHash = AIKPublicKeyHash(h, pub)
|
||||
binding = BindingValue(h, pub, exportedValue)
|
||||
return exportedValue, aikPubHash, binding, nil
|
||||
}
|
||||
|
||||
func authenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) {
|
||||
switch cipherSuite {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return crypto.SHA256, nil
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
return crypto.SHA384, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("attestation: unknown cipher suite: %s (0x%04x)", tls.CipherSuiteName(cipherSuite), cipherSuite)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type stubEvidenceVerifier struct {
|
||||
called bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubEvidenceVerifier) VerifyEvidence(evidence []byte) error {
|
||||
s.called = true
|
||||
return s.err
|
||||
}
|
||||
|
||||
type stubResultsVerifier struct {
|
||||
called bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubResultsVerifier) VerifyAttestationResults(results []byte) error {
|
||||
s.called = true
|
||||
return s.err
|
||||
}
|
||||
|
||||
func makeCert(t *testing.T) (tls.Certificate, *x509.Certificate) {
|
||||
t.Helper()
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "binding"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}, leaf
|
||||
}
|
||||
|
||||
func tls13Client(t *testing.T, cert tls.Certificate) (*tls.Conn, *tls.Conn) {
|
||||
t.Helper()
|
||||
srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
|
||||
cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13}
|
||||
a, b := net.Pipe()
|
||||
srv := tls.Server(a, srvConf)
|
||||
cli := tls.Client(b, cliConf)
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- srv.Handshake() }()
|
||||
go func() { errCh <- cli.Handshake() }()
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("handshake: %v", err)
|
||||
}
|
||||
}
|
||||
return srv, cli
|
||||
}
|
||||
|
||||
func TestComputeBindingDeterministic(t *testing.T) {
|
||||
cert, leaf := makeCert(t)
|
||||
srv, cli := tls13Client(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
st := cli.ConnectionState()
|
||||
ctx := []byte{1, 2, 3, 4}
|
||||
|
||||
ev1, aik1, b1, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ev2, aik2, b2, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(ev1) != ExportedAttestationValueLen {
|
||||
t.Fatalf("unexpected exported len: %d", len(ev1))
|
||||
}
|
||||
if !bytes.Equal(ev1, ev2) || !bytes.Equal(aik1, aik2) || !bytes.Equal(b1, b2) {
|
||||
t.Fatalf("expected deterministic outputs for same conn+context")
|
||||
}
|
||||
if bytes.Equal(aik1, b1) {
|
||||
t.Fatalf("unexpected aik == binding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadRoundTrip(t *testing.T) {
|
||||
payload := Payload{
|
||||
Version: 1,
|
||||
MediaType: "application/eat+cwt",
|
||||
Evidence: []byte("evidence"),
|
||||
Binder: AttestationBinder{
|
||||
ExporterLabel: ExporterLabelAttestation,
|
||||
AIKPubHash: []byte("aik"),
|
||||
Binding: []byte("binding"),
|
||||
},
|
||||
}
|
||||
|
||||
raw, err := MarshalPayload(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsed, err := ParsePayload(raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if parsed.Version != 1 || parsed.MediaType != "application/eat+cwt" || string(parsed.Evidence) != "evidence" {
|
||||
t.Fatalf("unexpected parsed payload: %#v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPayloadSuccess(t *testing.T) {
|
||||
cert, leaf := makeCert(t)
|
||||
srv, cli := tls13Client(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
st := cli.ConnectionState()
|
||||
ctx := []byte{1, 2, 3, 4}
|
||||
_, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ev := &stubEvidenceVerifier{}
|
||||
rv := &stubResultsVerifier{}
|
||||
payload := &Payload{
|
||||
Version: 1,
|
||||
Evidence: []byte("evidence"),
|
||||
AttestationResults: []byte("results"),
|
||||
Binder: AttestationBinder{
|
||||
ExporterLabel: ExporterLabelAttestation,
|
||||
AIKPubHash: aik,
|
||||
Binding: binding,
|
||||
},
|
||||
}
|
||||
|
||||
verified, err := VerifyPayload(&st, ExporterLabelAttestation, ctx, leaf, payload, VerificationPolicy{
|
||||
EvidenceVerifier: ev,
|
||||
ResultsVerifier: rv,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !verified.BindingVerified || !verified.EvidenceVerified || !verified.ResultsVerified {
|
||||
t.Fatalf("unexpected verification result: %#v", verified)
|
||||
}
|
||||
if !ev.called || !rv.called {
|
||||
t.Fatalf("expected both verifiers to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyBinderRejectsMismatch(t *testing.T) {
|
||||
cert, leaf := makeCert(t)
|
||||
srv, cli := tls13Client(t, cert)
|
||||
defer srv.Close()
|
||||
defer cli.Close()
|
||||
|
||||
st := cli.ConnectionState()
|
||||
ctx := []byte{1, 2, 3, 4}
|
||||
_, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
binding[0] ^= 0xff
|
||||
|
||||
err = VerifyBinder(&st, ExporterLabelAttestation, ctx, leaf, AttestationBinder{
|
||||
AIKPubHash: aik,
|
||||
Binding: binding,
|
||||
})
|
||||
if err != ErrBindingMismatch {
|
||||
t.Fatalf("got %v, want %v", err, ErrBindingMismatch)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMalformedPayload = errors.New("attestation: malformed payload")
|
||||
ErrMissingStatement = errors.New("attestation: missing evidence or attestation results")
|
||||
ErrMissingBinder = errors.New("attestation: missing attestation binder")
|
||||
ErrAIKPubHashMismatch = errors.New("attestation: AIK public key hash mismatch")
|
||||
ErrBindingMismatch = errors.New("attestation: attestation binding mismatch")
|
||||
ErrEvidenceVerificationMissing = errors.New("attestation: evidence verifier not configured")
|
||||
ErrResultsVerificationMissing = errors.New("attestation: attestation results verifier not configured")
|
||||
)
|
||||
|
||||
// Payload models the attestation document carried inside the EA certificate-entry extension.
|
||||
// It intentionally separates the attestation statement from the TLS binding material.
|
||||
type Payload struct {
|
||||
Version int `json:"version"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Evidence []byte `json:"evidence,omitempty"`
|
||||
AttestationResults []byte `json:"attestation_results,omitempty"`
|
||||
Binder AttestationBinder `json:"binder"`
|
||||
}
|
||||
|
||||
type AttestationBinder struct {
|
||||
ExporterLabel string `json:"exporter_label,omitempty"`
|
||||
AIKPubHash []byte `json:"aik_pub_hash,omitempty"`
|
||||
Binding []byte `json:"binding,omitempty"`
|
||||
}
|
||||
|
||||
type VerifiedPayload struct {
|
||||
Payload *Payload
|
||||
EvidenceVerified bool
|
||||
ResultsVerified bool
|
||||
BindingVerified bool
|
||||
UsedExporterLabel string
|
||||
}
|
||||
|
||||
func (p *Payload) Validate() error {
|
||||
if p == nil {
|
||||
return ErrMalformedPayload
|
||||
}
|
||||
if len(p.Evidence) == 0 && len(p.AttestationResults) == 0 {
|
||||
return ErrMissingStatement
|
||||
}
|
||||
if len(p.Binder.AIKPubHash) == 0 || len(p.Binder.Binding) == 0 {
|
||||
return ErrMissingBinder
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Payload) NormalizedExporterLabel(defaultLabel string) string {
|
||||
if p == nil || p.Binder.ExporterLabel == "" {
|
||||
return defaultLabel
|
||||
}
|
||||
return p.Binder.ExporterLabel
|
||||
}
|
||||
|
||||
func MarshalPayload(p Payload) ([]byte, error) {
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.Version == 0 {
|
||||
p.Version = 1
|
||||
}
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
func ParsePayload(raw []byte) (*Payload, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, ErrMalformedPayload
|
||||
}
|
||||
var p Payload
|
||||
if err := json.Unmarshal(raw, &p); err != nil {
|
||||
return nil, ErrMalformedPayload
|
||||
}
|
||||
if err := p.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
type EvidenceVerifier interface {
|
||||
VerifyEvidence(evidence []byte) error
|
||||
}
|
||||
|
||||
type ResultsVerifier interface {
|
||||
VerifyAttestationResults(results []byte) error
|
||||
}
|
||||
|
||||
type VerificationPolicy struct {
|
||||
EvidenceVerifier EvidenceVerifier
|
||||
ResultsVerifier ResultsVerifier
|
||||
}
|
||||
|
||||
func VerifyPayload(st *tls.ConnectionState, defaultLabel string, certificateRequestContext []byte, leaf *x509.Certificate, payload *Payload, policy VerificationPolicy) (*VerifiedPayload, error) {
|
||||
if err := payload.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verified := &VerifiedPayload{
|
||||
Payload: payload,
|
||||
UsedExporterLabel: payload.NormalizedExporterLabel(defaultLabel),
|
||||
}
|
||||
|
||||
if len(payload.Evidence) > 0 && policy.EvidenceVerifier != nil {
|
||||
if err := policy.EvidenceVerifier.VerifyEvidence(payload.Evidence); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verified.EvidenceVerified = true
|
||||
} else if len(payload.Evidence) > 0 {
|
||||
return nil, ErrEvidenceVerificationMissing
|
||||
}
|
||||
if len(payload.AttestationResults) > 0 && policy.ResultsVerifier != nil {
|
||||
if err := policy.ResultsVerifier.VerifyAttestationResults(payload.AttestationResults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verified.ResultsVerified = true
|
||||
} else if len(payload.AttestationResults) > 0 {
|
||||
return nil, ErrResultsVerificationMissing
|
||||
}
|
||||
if err := VerifyBinder(st, verified.UsedExporterLabel, certificateRequestContext, leaf, payload.Binder); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verified.BindingVerified = true
|
||||
return verified, nil
|
||||
}
|
||||
|
||||
func VerifyBinder(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate, binder AttestationBinder) error {
|
||||
exportedValue, aikPubHash, binding, err := ComputeBinding(st, label, certificateRequestContext, leaf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = exportedValue
|
||||
if !equalBytes(aikPubHash, binder.AIKPubHash) {
|
||||
return ErrAIKPubHashMismatch
|
||||
}
|
||||
if !equalBytes(binding, binder.Binding) {
|
||||
return ErrBindingMismatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func equalBytes(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/veraison/corim/corim"
|
||||
)
|
||||
|
||||
type policyEvidenceVerifier struct {
|
||||
policyPath string
|
||||
}
|
||||
|
||||
func NewEvidenceVerifier(policyPath string) eaattestation.EvidenceVerifier {
|
||||
return &policyEvidenceVerifier{policyPath: policyPath}
|
||||
}
|
||||
|
||||
func (v *policyEvidenceVerifier) VerifyEvidence(evidence []byte) error {
|
||||
if v.policyPath == "" {
|
||||
return fmt.Errorf("atls: attestation policy path is not set")
|
||||
}
|
||||
claims, err := eat.DecodeCBOR(evidence, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("atls: failed to decode EAT evidence: %w", err)
|
||||
}
|
||||
manifest, err := loadCoRIM(v.policyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
verifier, err := platformVerifier(platformTypeFromClaims(claims.PlatformType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return verifier.VerifyWithCoRIM(claims.RawReport, manifest)
|
||||
}
|
||||
|
||||
func loadCoRIM(path string) (*corim.UnsignedCorim, error) {
|
||||
corimBytes, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("atls: failed to read CoRIM file: %w", err)
|
||||
}
|
||||
|
||||
var sc corim.SignedCorim
|
||||
if err := sc.FromCOSE(corimBytes); err == nil {
|
||||
return &sc.UnsignedCorim, nil
|
||||
}
|
||||
|
||||
var uc corim.UnsignedCorim
|
||||
if err := uc.FromCBOR(corimBytes); err != nil {
|
||||
return nil, fmt.Errorf("atls: failed to parse CoRIM: %w", err)
|
||||
}
|
||||
return &uc, nil
|
||||
}
|
||||
|
||||
func platformTypeFromClaims(name string) cocosattestation.PlatformType {
|
||||
switch name {
|
||||
case "SNP":
|
||||
return cocosattestation.SNP
|
||||
case "TDX":
|
||||
return cocosattestation.TDX
|
||||
case "vTPM":
|
||||
return cocosattestation.VTPM
|
||||
case "SNP-vTPM":
|
||||
return cocosattestation.SNPvTPM
|
||||
case "Azure":
|
||||
return cocosattestation.Azure
|
||||
default:
|
||||
return cocosattestation.NoCC
|
||||
}
|
||||
}
|
||||
|
||||
func platformVerifier(platformType cocosattestation.PlatformType) (cocosattestation.Verifier, error) {
|
||||
switch platformType {
|
||||
case cocosattestation.SNP, cocosattestation.SNPvTPM, cocosattestation.VTPM:
|
||||
return vtpm.NewVerifier(nil), nil
|
||||
case cocosattestation.Azure:
|
||||
return azure.NewVerifier(nil), nil
|
||||
case cocosattestation.TDX:
|
||||
return tdx.NewVerifier(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("atls: unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package internaltransport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/ea"
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*tls.Conn
|
||||
Request *ea.AuthenticatorRequest
|
||||
ValidationResult *ea.ValidationResult
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
Session *ea.Session
|
||||
VerifyOptions *x509.VerifyOptions
|
||||
AttestationPolicy eaattestation.VerificationPolicy
|
||||
Request *ea.AuthenticatorRequest
|
||||
RequestBuilder func() (*ea.AuthenticatorRequest, error)
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
Session *ea.Session
|
||||
Identity tls.Certificate
|
||||
BuildLeafExtensions func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error)
|
||||
}
|
||||
|
||||
func Dial(network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, address, cfg)
|
||||
}
|
||||
|
||||
func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
return DialContextWithDialer(ctx, new(net.Dialer), network, address, cfg)
|
||||
}
|
||||
|
||||
func DialWithDialer(d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
if cfg == nil || cfg.TLSConfig == nil {
|
||||
return nil, fmt.Errorf("atls: missing client TLS config")
|
||||
}
|
||||
|
||||
rawConn, err := d.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone())
|
||||
conn, err := Client(tlsConn, cfg)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func DialContextWithDialer(ctx context.Context, d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
if cfg == nil || cfg.TLSConfig == nil {
|
||||
return nil, fmt.Errorf("atls: missing client TLS config")
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, fmt.Errorf("atls: missing client context")
|
||||
}
|
||||
|
||||
rawConn, err := d.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone())
|
||||
conn, err := ClientContext(ctx, tlsConn, cfg)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
|
||||
return ClientContext(context.Background(), tlsConn, cfg)
|
||||
}
|
||||
|
||||
func ClientContext(ctx context.Context, tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("atls: missing client config")
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, fmt.Errorf("atls: missing client context")
|
||||
}
|
||||
var res *Conn
|
||||
err := withConnContext(ctx, tlsConn, func() error {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := buildRequest(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reqBytes, err := req.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeFrame(tlsConn, frameTypeRequest, reqBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
frameType, authBytes, err := readFrame(tlsConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if frameType != frameTypeAuthenticator {
|
||||
return fmt.Errorf("atls: unexpected frame type %d", frameType)
|
||||
}
|
||||
|
||||
st := tlsConn.ConnectionState()
|
||||
var validation *ea.ValidationResult
|
||||
if cfg.Session != nil {
|
||||
validation, err = cfg.Session.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy)
|
||||
} else {
|
||||
validation, err = ea.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res = &Conn{Conn: tlsConn, Request: req, ValidationResult: validation}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("atls: missing server config")
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
frameType, reqBytes, err := readFrame(tlsConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if frameType != frameTypeRequest {
|
||||
return nil, fmt.Errorf("atls: unexpected frame type %d", frameType)
|
||||
}
|
||||
req, rest, err := ea.UnmarshalAuthenticatorRequest(reqBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("atls: trailing request bytes")
|
||||
}
|
||||
|
||||
st := tlsConn.ConnectionState()
|
||||
identity, err := resolveIdentity(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exts, err := buildServerExtensions(cfg, &st, &req, identity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var authBytes []byte
|
||||
if cfg.Session != nil {
|
||||
authBytes, err = cfg.Session.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts)
|
||||
} else {
|
||||
authBytes, err = ea.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := writeFrame(tlsConn, frameTypeAuthenticator, authBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{Conn: tlsConn, Request: &req}, nil
|
||||
}
|
||||
|
||||
func buildRequest(cfg *ClientConfig) (*ea.AuthenticatorRequest, error) {
|
||||
if cfg.RequestBuilder != nil {
|
||||
return cfg.RequestBuilder()
|
||||
}
|
||||
if cfg.Request != nil {
|
||||
return cfg.Request, nil
|
||||
}
|
||||
ctx, err := ea.NewRandomContext(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ea.AuthenticatorRequest{
|
||||
Type: ea.HandshakeTypeClientCertificateRequest,
|
||||
Context: ctx,
|
||||
Extensions: []ea.Extension{
|
||||
sigExt,
|
||||
ea.CMWAttestationOfferExtension(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func resolveIdentity(cfg *ServerConfig) (tls.Certificate, error) {
|
||||
if len(cfg.Identity.Certificate) > 0 && cfg.Identity.PrivateKey != nil {
|
||||
return cfg.Identity, nil
|
||||
}
|
||||
if cfg.TLSConfig != nil && len(cfg.TLSConfig.Certificates) > 0 {
|
||||
return cfg.TLSConfig.Certificates[0], nil
|
||||
}
|
||||
return tls.Certificate{}, fmt.Errorf("atls: missing server identity")
|
||||
}
|
||||
|
||||
func buildServerExtensions(cfg *ServerConfig, st *tls.ConnectionState, req *ea.AuthenticatorRequest, identity tls.Certificate) ([]ea.Extension, error) {
|
||||
if cfg.BuildLeafExtensions == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if len(identity.Certificate) == 0 {
|
||||
return nil, fmt.Errorf("atls: missing server leaf certificate")
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(identity.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg.BuildLeafExtensions(st, req, leaf)
|
||||
}
|
||||
|
||||
func withConnContext(ctx context.Context, conn interface{ SetDeadline(time.Time) error }, fn func() error) error {
|
||||
if ctx == nil {
|
||||
return fn()
|
||||
}
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
}()
|
||||
} else {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = conn.SetDeadline(time.Now())
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
close(done)
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
|
||||
err := fn()
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return ctxErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package internaltransport
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func selfSignedCert(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "internal-transport"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAllowsIdentityWithoutTLSConfig(t *testing.T) {
|
||||
cert := selfSignedCert(t)
|
||||
a, b := net.Pipe()
|
||||
|
||||
serverTLS := tls.Server(a, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
})
|
||||
clientTLS := tls.Client(b, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
})
|
||||
|
||||
type result struct {
|
||||
conn *Conn
|
||||
err error
|
||||
}
|
||||
serverCh := make(chan result, 1)
|
||||
clientCh := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := Server(serverTLS, &ServerConfig{
|
||||
Identity: cert,
|
||||
})
|
||||
serverCh <- result{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn, err := Client(clientTLS, &ClientConfig{
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
},
|
||||
})
|
||||
clientCh <- result{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
srvRes := <-serverCh
|
||||
cliRes := <-clientCh
|
||||
|
||||
if srvRes.err != nil {
|
||||
t.Fatalf("server failed: %v", srvRes.err)
|
||||
}
|
||||
if cliRes.err != nil {
|
||||
t.Fatalf("client failed: %v", cliRes.err)
|
||||
}
|
||||
|
||||
defer srvRes.conn.Close()
|
||||
defer cliRes.conn.Close()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package internaltransport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
raw net.Listener
|
||||
cfg *ServerConfig
|
||||
}
|
||||
|
||||
func Listen(network, address string, cfg *ServerConfig) (*Listener, error) {
|
||||
if cfg == nil || cfg.TLSConfig == nil {
|
||||
return nil, fmt.Errorf("atls: missing server TLS config")
|
||||
}
|
||||
raw, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{raw: raw, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
rawConn, err := l.raw.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConn := tls.Server(rawConn, l.cfg.TLSConfig.Clone())
|
||||
conn, err := Server(tlsConn, l.cfg)
|
||||
if err != nil {
|
||||
_ = tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
return l.raw.Close()
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.raw.Addr()
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package internaltransport
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
frameTypeRequest uint8 = iota + 1
|
||||
frameTypeAuthenticator
|
||||
maxFramePayloadLen = 16 << 20 // 16 MiB
|
||||
)
|
||||
|
||||
func writeFrame(w io.Writer, typ uint8, payload []byte) error {
|
||||
if len(payload) > maxFramePayloadLen {
|
||||
return fmt.Errorf("atls: frame payload too large: %d", len(payload))
|
||||
}
|
||||
header := make([]byte, 5)
|
||||
header[0] = typ
|
||||
binary.BigEndian.PutUint32(header[1:5], uint32(len(payload)))
|
||||
if err := writeAll(w, header); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
return writeAll(w, payload)
|
||||
}
|
||||
|
||||
func writeAll(w io.Writer, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readFrame(r io.Reader) (uint8, []byte, error) {
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if header[0] == 0 {
|
||||
return 0, nil, fmt.Errorf("atls: invalid frame type")
|
||||
}
|
||||
n := binary.BigEndian.Uint32(header[1:5])
|
||||
if n > maxFramePayloadLen {
|
||||
return 0, nil, fmt.Errorf("atls: frame payload too large: %d", n)
|
||||
}
|
||||
payload := make([]byte, n)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return header[0], payload, nil
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package internaltransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadFrameRejectsOversizedPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
header := make([]byte, 5)
|
||||
header[0] = frameTypeRequest
|
||||
binary.BigEndian.PutUint32(header[1:5], maxFramePayloadLen+1)
|
||||
if _, err := buf.Write(header); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, _, err := readFrame(&buf)
|
||||
if err == nil {
|
||||
t.Fatal("expected oversized frame error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "frame payload too large") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFrameRejectsOversizedPayload(t *testing.T) {
|
||||
payload := make([]byte, maxFramePayloadLen+1)
|
||||
err := writeFrame(bytes.NewBuffer(nil), frameTypeAuthenticator, payload)
|
||||
if err == nil {
|
||||
t.Fatal("expected oversized frame error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "frame payload too large") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,103 +1,53 @@
|
||||
// Code generated manually for tests.
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/ea"
|
||||
)
|
||||
|
||||
// NewCertificateProvider creates a new instance of CertificateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewCertificateProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *CertificateProvider {
|
||||
mock := &CertificateProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// CertificateProvider is an autogenerated mock type for the CertificateProvider type
|
||||
type CertificateProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type CertificateProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
func (_m *CertificateProvider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) {
|
||||
ret := _m.Called(st, req, leaf)
|
||||
|
||||
func (_m *CertificateProvider) EXPECT() *CertificateProvider_Expecter {
|
||||
return &CertificateProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// GetCertificate provides a mock function for the type CertificateProvider
|
||||
func (_mock *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ret := _mock.Called(clientHello)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetCertificate")
|
||||
}
|
||||
|
||||
var r0 *tls.Certificate
|
||||
var r0 []ea.Extension
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok {
|
||||
return returnFunc(clientHello)
|
||||
|
||||
if rf, ok := ret.Get(0).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error)); ok {
|
||||
return rf(st, req, leaf)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) *tls.Certificate); ok {
|
||||
r0 = returnFunc(clientHello)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*tls.Certificate)
|
||||
}
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]ea.Extension)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(*tls.ClientHelloInfo) error); ok {
|
||||
r1 = returnFunc(clientHello)
|
||||
|
||||
if rf, ok := ret.Get(1).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) error); ok {
|
||||
r1 = rf(st, req, leaf)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// CertificateProvider_GetCertificate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCertificate'
|
||||
type CertificateProvider_GetCertificate_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
func NewCertificateProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *CertificateProvider {
|
||||
mockProvider := &CertificateProvider{}
|
||||
mockProvider.Mock.Test(t)
|
||||
|
||||
// GetCertificate is a helper method to define mock.On call
|
||||
// - clientHello *tls.ClientHelloInfo
|
||||
func (_e *CertificateProvider_Expecter) GetCertificate(clientHello interface{}) *CertificateProvider_GetCertificate_Call {
|
||||
return &CertificateProvider_GetCertificate_Call{Call: _e.mock.On("GetCertificate", clientHello)}
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) Run(run func(clientHello *tls.ClientHelloInfo)) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 *tls.ClientHelloInfo
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*tls.ClientHelloInfo)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
mockProvider.AssertExpectations(t)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) Return(certificate *tls.Certificate, err error) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Return(certificate, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) RunAndReturn(run func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
return mockProvider
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/certs/sdk"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/ea"
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
attestationclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
)
|
||||
|
||||
// CertificateProvider is kept for compatibility with existing cocos call sites.
|
||||
// In the EA-based implementation it provides the leaf certificate-entry extensions
|
||||
// carried in the exported authenticator instead of generating TLS certificates.
|
||||
type CertificateProvider interface {
|
||||
BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
attClient attestationclient.Client
|
||||
platformType cocosattestation.PlatformType
|
||||
}
|
||||
|
||||
func NewProvider(attClient attestationclient.Client, platformType cocosattestation.PlatformType, _ string, _ string, _ sdk.SDK) (CertificateProvider, error) {
|
||||
if attClient == nil {
|
||||
return nil, fmt.Errorf("atls: missing attestation client")
|
||||
}
|
||||
if platformType == cocosattestation.NoCC {
|
||||
return nil, fmt.Errorf("atls: confidential computing platform not available")
|
||||
}
|
||||
return &provider{
|
||||
attClient: attClient,
|
||||
platformType: platformType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *provider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) {
|
||||
if st == nil || req == nil || leaf == nil {
|
||||
return nil, fmt.Errorf("atls: missing state, request, or leaf certificate")
|
||||
}
|
||||
exportedValue, aikPubHash, binding, err := eaattestation.ComputeBinding(st, eaattestation.ExporterLabelAttestation, req.Context, leaf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reportData := sha512.Sum512(binding)
|
||||
nonceBytes := sha256.Sum256(exportedValue)
|
||||
var nonce [32]byte
|
||||
copy(nonce[:], nonceBytes[:])
|
||||
|
||||
evidence, err := p.attClient.GetAttestation(context.Background(), reportData, nonce, p.platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("atls: failed to fetch attestation evidence: %w", err)
|
||||
}
|
||||
|
||||
payloadBytes, err := eaattestation.MarshalPayload(eaattestation.Payload{
|
||||
Version: 1,
|
||||
MediaType: "application/eat+cwt",
|
||||
Evidence: evidence,
|
||||
Binder: eaattestation.AttestationBinder{
|
||||
ExporterLabel: eaattestation.ExporterLabelAttestation,
|
||||
AIKPubHash: aikPubHash,
|
||||
Binding: binding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ext, err := ea.CMWAttestationDataExtension(payloadBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ea.Extension{ext}, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BuildServerTLSConfig prepares the base TLS configuration used by the EA/aTLS
|
||||
// transport. If no certificate/key pair is configured, it falls back to an
|
||||
// ephemeral self-signed identity bound by the exported authenticator.
|
||||
func BuildServerTLSConfig(certFile, keyFile, serverCAFile, clientCAFile string) (*tls.Config, tls.Certificate, bool, error) {
|
||||
if certFile != "" || keyFile != "" {
|
||||
tlsSetup, err := setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile)
|
||||
if err != nil {
|
||||
return nil, tls.Certificate{}, false, err
|
||||
}
|
||||
tlsSetup.config.MinVersion = tls.VersionTLS13
|
||||
return tlsSetup.config, tlsSetup.config.Certificates[0], tlsSetup.mtls, nil
|
||||
}
|
||||
|
||||
identity, err := generateEphemeralIdentity()
|
||||
if err != nil {
|
||||
return nil, tls.Certificate{}, false, fmt.Errorf("failed to generate ephemeral TLS identity: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{identity},
|
||||
}
|
||||
|
||||
mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
|
||||
if err != nil {
|
||||
return nil, tls.Certificate{}, false, err
|
||||
}
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return tlsConfig, identity, mtls, nil
|
||||
}
|
||||
|
||||
func generateEphemeralIdentity() (tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialLimit)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "cocos-atls-ephemeral",
|
||||
Organization: []string{"Ultraviolet"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
return tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
Leaf: leaf,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type tlsSetupResult struct {
|
||||
config *tls.Config
|
||||
mtls bool
|
||||
}
|
||||
|
||||
func readFileOrData(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return []byte(input), nil
|
||||
}
|
||||
|
||||
func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) {
|
||||
cert, err := readFileOrData(certFile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read cert: %w", err)
|
||||
}
|
||||
|
||||
key, err := readFileOrData(keyFile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read key: %w", err)
|
||||
}
|
||||
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
|
||||
func loadCertFile(certFile string) ([]byte, error) {
|
||||
if certFile == "" {
|
||||
return []byte{}, nil
|
||||
}
|
||||
return readFileOrData(certFile)
|
||||
}
|
||||
|
||||
func configureCertificateAuthorities(tlsConfig *tls.Config, serverCAFile, clientCAFile string) (bool, error) {
|
||||
rootCA, err := loadCertFile(serverCAFile)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load server ca file: %w", err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
|
||||
return false, fmt.Errorf("failed to append server ca to tls.Config")
|
||||
}
|
||||
}
|
||||
|
||||
clientCA, err := loadCertFile(clientCAFile)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load client ca file: %w", err)
|
||||
}
|
||||
if len(clientCA) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
tlsConfig.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
|
||||
return false, fmt.Errorf("failed to append client ca to tls.Config")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*tlsSetupResult, error) {
|
||||
certificate, err := loadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load auth certificates: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
}
|
||||
|
||||
mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return &tlsSetupResult{config: tlsConfig, mtls: mtls}, nil
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/ea"
|
||||
eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation"
|
||||
internaltransport "github.com/ultravioletrs/cocos/pkg/atls/internal_transport"
|
||||
)
|
||||
|
||||
type Conn = internaltransport.Conn
|
||||
|
||||
type Listener = internaltransport.Listener
|
||||
|
||||
type ClientConfig = internaltransport.ClientConfig
|
||||
|
||||
type ServerConfig = internaltransport.ServerConfig
|
||||
|
||||
type AuthenticatorRequest = ea.AuthenticatorRequest
|
||||
|
||||
func Dial(network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
return internaltransport.Dial(network, address, cfg)
|
||||
}
|
||||
|
||||
func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
return internaltransport.DialContext(ctx, network, address, cfg)
|
||||
}
|
||||
|
||||
func DialWithDialer(dialer *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) {
|
||||
return internaltransport.DialWithDialer(dialer, network, address, cfg)
|
||||
}
|
||||
|
||||
func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) {
|
||||
return internaltransport.Client(tlsConn, cfg)
|
||||
}
|
||||
|
||||
func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) {
|
||||
return internaltransport.Server(tlsConn, cfg)
|
||||
}
|
||||
|
||||
func Listen(network, address string, cfg *ServerConfig) (*Listener, error) {
|
||||
return internaltransport.Listen(network, address, cfg)
|
||||
}
|
||||
|
||||
func NewRequest(context []byte) (*ea.AuthenticatorRequest, error) {
|
||||
sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ea.AuthenticatorRequest{
|
||||
Type: ea.HandshakeTypeClientCertificateRequest,
|
||||
Context: context,
|
||||
Extensions: []ea.Extension{
|
||||
sigExt,
|
||||
ea.CMWAttestationOfferExtension(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewRandomRequest(contextLen int) (*ea.AuthenticatorRequest, error) {
|
||||
context, err := ea.NewRandomContext(contextLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRequest(context)
|
||||
}
|
||||
|
||||
func VerifyOptionsFromTLSConfig(cfg *tls.Config) *x509.VerifyOptions {
|
||||
if cfg == nil || cfg.InsecureSkipVerify || cfg.RootCAs == nil {
|
||||
return nil
|
||||
}
|
||||
return &x509.VerifyOptions{Roots: cfg.RootCAs}
|
||||
}
|
||||
|
||||
func VerificationPolicyFromEvidenceVerifier(v eaattestation.EvidenceVerifier) eaattestation.VerificationPolicy {
|
||||
return eaattestation.VerificationPolicy{EvidenceVerifier: v}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerifyOptionsFromTLSConfig(t *testing.T) {
|
||||
t.Run("nil config", func(t *testing.T) {
|
||||
if got := VerifyOptionsFromTLSConfig(nil); got != nil {
|
||||
t.Fatalf("expected nil verify options, got %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skip verify disables ea chain validation", func(t *testing.T) {
|
||||
got := VerifyOptionsFromTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
})
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil verify options for insecure skip verify, got %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing roots disables ea chain validation", func(t *testing.T) {
|
||||
got := VerifyOptionsFromTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
})
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil verify options when roots are not configured, got %#v", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("configured roots are propagated", func(t *testing.T) {
|
||||
roots := x509.NewCertPool()
|
||||
got := VerifyOptionsFromTLSConfig(&tls.Config{
|
||||
RootCAs: roots,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
})
|
||||
if got == nil {
|
||||
t.Fatal("expected verify options, got nil")
|
||||
}
|
||||
if got.Roots != roots {
|
||||
t.Fatal("expected verify options to reuse configured root CAs")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewRandomRequest(t *testing.T) {
|
||||
req1, err := NewRandomRequest(32)
|
||||
if err != nil {
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
}
|
||||
req2, err := NewRandomRequest(32)
|
||||
if err != nil {
|
||||
t.Fatalf("second request failed: %v", err)
|
||||
}
|
||||
|
||||
if len(req1.Context) != 32 {
|
||||
t.Fatalf("expected first request context length 32, got %d", len(req1.Context))
|
||||
}
|
||||
if len(req2.Context) != 32 {
|
||||
t.Fatalf("expected second request context length 32, got %d", len(req2.Context))
|
||||
}
|
||||
if len(req1.Extensions) == 0 {
|
||||
t.Fatal("expected first request to carry extensions")
|
||||
}
|
||||
if len(req2.Extensions) == 0 {
|
||||
t.Fatal("expected second request to carry extensions")
|
||||
}
|
||||
if string(req1.Context) == string(req2.Context) {
|
||||
t.Fatal("expected random request contexts to differ")
|
||||
}
|
||||
}
|
||||
+32
-1
@@ -3,11 +3,18 @@
|
||||
|
||||
package clients
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_ ClientConfiguration = (*AttestedClientConfig)(nil)
|
||||
_ ClientConfiguration = (*StandardClientConfig)(nil)
|
||||
|
||||
ErrInvalidAttestationRequestContext = errors.New("invalid attestation request context")
|
||||
)
|
||||
|
||||
type ClientConfiguration interface {
|
||||
@@ -29,6 +36,13 @@ type AttestedClientConfig struct {
|
||||
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
|
||||
// AttestationRequestContextHex, when set, is decoded from hex and used as
|
||||
// the exported authenticator certificate_request_context. This lets the
|
||||
// caller provide the background-check freshness value directly.
|
||||
AttestationRequestContextHex string `env:"ATTESTATION_REQUEST_CONTEXT" envDefault:""`
|
||||
// AttestationRequestContext allows callers inside the same process to pass
|
||||
// raw request-context bytes directly instead of using the hex string form.
|
||||
AttestationRequestContext []byte `env:"-"`
|
||||
}
|
||||
|
||||
func (c AttestedClientConfig) Config() StandardClientConfig {
|
||||
@@ -38,3 +52,20 @@ func (c AttestedClientConfig) Config() StandardClientConfig {
|
||||
func (c StandardClientConfig) Config() StandardClientConfig {
|
||||
return c
|
||||
}
|
||||
|
||||
func (c AttestedClientConfig) RequestContext() ([]byte, error) {
|
||||
if len(c.AttestationRequestContext) > 0 {
|
||||
return append([]byte(nil), c.AttestationRequestContext...), nil
|
||||
}
|
||||
if c.AttestationRequestContextHex == "" {
|
||||
return nil, nil
|
||||
}
|
||||
requestContext, err := hex.DecodeString(c.AttestationRequestContextHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrInvalidAttestationRequestContext, err)
|
||||
}
|
||||
if len(requestContext) == 0 {
|
||||
return nil, fmt.Errorf("%w: decoded value is empty", ErrInvalidAttestationRequestContext)
|
||||
}
|
||||
return requestContext, nil
|
||||
}
|
||||
|
||||
@@ -103,6 +103,22 @@ func TestNewClient(t *testing.T) {
|
||||
wantErr: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Success agent client with aTLS and custom request context",
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
ClientKey: clientKeyFile,
|
||||
},
|
||||
AttestedTLS: true,
|
||||
AttestationPolicy: policyFile.Name(),
|
||||
AttestationRequestContextHex: "01020304",
|
||||
},
|
||||
wantErr: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Failed agent client with aTLS",
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
@@ -118,6 +134,22 @@ func TestNewClient(t *testing.T) {
|
||||
wantErr: true,
|
||||
err: fmt.Errorf("failed to stat attestation policy file"),
|
||||
},
|
||||
{
|
||||
name: "Failed agent client with invalid attestation request context",
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
ClientKey: clientKeyFile,
|
||||
},
|
||||
AttestedTLS: true,
|
||||
AttestationPolicy: policyFile.Name(),
|
||||
AttestationRequestContextHex: "xyz",
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrInvalidAttestationRequestContext,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ServerCAFile",
|
||||
cfg: clients.StandardClientConfig{
|
||||
|
||||
@@ -4,7 +4,13 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdtls "crypto/tls"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
@@ -77,7 +83,43 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(result.Config)))
|
||||
tlsConfig := result.Config.Clone()
|
||||
tlsConfig.MinVersion = stdtls.VersionTLS13
|
||||
tlsConfig.NextProtos = []string{"h2"}
|
||||
|
||||
atlsConfig := &atls.ClientConfig{
|
||||
TLSConfig: tlsConfig,
|
||||
VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig),
|
||||
AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)),
|
||||
}
|
||||
requestContext, err := agcfg.RequestContext()
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
if len(requestContext) > 0 {
|
||||
req, err := atls.NewRequest(requestContext)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
atlsConfig.Request = req
|
||||
} else {
|
||||
atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) {
|
||||
return atls.NewRandomRequest(32)
|
||||
}
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
network, target := dialTarget(addr)
|
||||
conn, err := atls.DialContext(ctx, network, target, atlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}),
|
||||
)
|
||||
security = result.Security
|
||||
} else {
|
||||
conf := cfg.Config()
|
||||
@@ -96,6 +138,16 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e
|
||||
return conn, security, nil
|
||||
}
|
||||
|
||||
func dialTarget(addr string) (string, string) {
|
||||
if strings.HasPrefix(addr, "unix://") {
|
||||
return "unix", strings.TrimPrefix(addr, "unix://")
|
||||
}
|
||||
if strings.HasPrefix(addr, "/") {
|
||||
return "unix", addr
|
||||
}
|
||||
return "tcp", addr
|
||||
}
|
||||
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
|
||||
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,9 +4,14 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdtls "crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
@@ -70,7 +75,33 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = result.Config
|
||||
tlsConfig := result.Config.Clone()
|
||||
tlsConfig.MinVersion = stdtls.VersionTLS13
|
||||
atlsConfig := &atls.ClientConfig{
|
||||
TLSConfig: tlsConfig,
|
||||
VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig),
|
||||
AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)),
|
||||
}
|
||||
requestContext, err := agcfg.RequestContext()
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
if len(requestContext) > 0 {
|
||||
req, err := atls.NewRequest(requestContext)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
atlsConfig.Request = req
|
||||
} else {
|
||||
atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) {
|
||||
return atls.NewRandomRequest(32)
|
||||
}
|
||||
}
|
||||
|
||||
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialNetwork, target := httpDialTarget(network, addr)
|
||||
return atls.DialContext(ctx, dialNetwork, target, atlsConfig)
|
||||
}
|
||||
security = result.Security
|
||||
} else {
|
||||
conf := cfg.Config()
|
||||
@@ -89,3 +120,13 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu
|
||||
|
||||
return transport, security, nil
|
||||
}
|
||||
|
||||
func httpDialTarget(network, addr string) (string, string) {
|
||||
if strings.HasPrefix(addr, "unix://") {
|
||||
return "unix", strings.TrimPrefix(addr, "unix://")
|
||||
}
|
||||
if strings.HasPrefix(addr, "/") {
|
||||
return "unix", addr
|
||||
}
|
||||
return network, addr
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -210,6 +211,62 @@ func TestCreateTransport_ATLSError(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "failed to stat attestation policy")
|
||||
}
|
||||
|
||||
func TestCreateTransport_ATLSCustomRequestContext(t *testing.T) {
|
||||
policyFile, err := os.CreateTemp("", "attestation_policy.json")
|
||||
assert.NoError(t, err)
|
||||
_, err = policyFile.WriteString("{}")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, policyFile.Close())
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(policyFile.Name())
|
||||
})
|
||||
|
||||
config := &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
AttestationPolicy: policyFile.Name(),
|
||||
AttestedTLS: true,
|
||||
AttestationRequestContextHex: "01020304",
|
||||
}
|
||||
|
||||
transport, security, err := createTransport(config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, transport)
|
||||
assert.Equal(t, tls.WithATLS, security)
|
||||
assert.NotNil(t, transport.DialTLSContext)
|
||||
}
|
||||
|
||||
func TestCreateTransport_ATLSInvalidRequestContext(t *testing.T) {
|
||||
policyFile, err := os.CreateTemp("", "attestation_policy.json")
|
||||
assert.NoError(t, err)
|
||||
_, err = policyFile.WriteString("{}")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, policyFile.Close())
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(policyFile.Name())
|
||||
})
|
||||
|
||||
config := &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
AttestationPolicy: policyFile.Name(),
|
||||
AttestedTLS: true,
|
||||
AttestationRequestContextHex: "xyz",
|
||||
}
|
||||
|
||||
transport, security, err := createTransport(config)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "invalid attestation request context")
|
||||
}
|
||||
|
||||
func TestCreateTransport_BasicTLSError(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "https://example.com",
|
||||
|
||||
+58
-28
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
@@ -18,6 +19,8 @@ import (
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
const unix = "unix"
|
||||
|
||||
// ProxyConfig contains configuration for starting a proxy instance.
|
||||
type ProxyConfig struct {
|
||||
Port string
|
||||
@@ -81,12 +84,12 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
var rp *httputil.ReverseProxy
|
||||
|
||||
// Check if backend is Unix socket or TCP
|
||||
if p.backendURL.Scheme == "unix" {
|
||||
if p.backendURL.Scheme == unix {
|
||||
// For Unix socket backend, we need to manually configure the reverse proxy
|
||||
// because NewSingleHostReverseProxy doesn't support unix:// scheme
|
||||
targetURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "unix",
|
||||
Host: unix,
|
||||
}
|
||||
rp = httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
@@ -96,7 +99,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
originalDirector(req)
|
||||
// Set the URL to point to the backend service
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = "unix"
|
||||
req.URL.Host = unix
|
||||
}
|
||||
|
||||
// Configure Transport for Unix socket with HTTP/2
|
||||
@@ -105,7 +108,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
// Use Unix socket path from URL
|
||||
return d.DialContext(ctx, "unix", p.backendURL.Path)
|
||||
return d.DialContext(ctx, unix, p.backendURL.Path)
|
||||
},
|
||||
}
|
||||
} else {
|
||||
@@ -130,6 +133,8 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
|
||||
// Configure TLS
|
||||
var tlsConfig *tls.Config
|
||||
var listener net.Listener
|
||||
var err error
|
||||
contextDesc := fmt.Sprintf("context %s", ctx.ID)
|
||||
if ctx.Name != "" {
|
||||
contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID)
|
||||
@@ -139,23 +144,31 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
if p.certProvider == nil {
|
||||
return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running")
|
||||
}
|
||||
tlsConfig = &tls.Config{
|
||||
GetCertificate: p.certProvider.GetCertificate,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
tlsConfig, identity, mtls, err := atls.BuildServerTLSConfig(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup attested TLS: %w", err)
|
||||
}
|
||||
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
|
||||
listener, err = p.attestedListener(addr, tlsConfig, identity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
p.started = true
|
||||
go func() {
|
||||
serveErr := p.httpServer.Serve(listener)
|
||||
if serveErr != nil && serveErr != http.ErrServerClosed {
|
||||
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc))
|
||||
} else {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc))
|
||||
}
|
||||
return nil
|
||||
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
// Regular TLS
|
||||
tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
@@ -174,31 +187,48 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc))
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
tcpListener, listenErr := net.Listen("tcp", addr)
|
||||
if listenErr != nil {
|
||||
return fmt.Errorf("failed to listen: %w", listenErr)
|
||||
}
|
||||
listener = tls.NewListener(tcpListener, tlsConfig)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.started = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
var err error
|
||||
if tlsConfig != nil {
|
||||
ln, listenErr := net.Listen("tcp", addr)
|
||||
if listenErr != nil {
|
||||
p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr))
|
||||
return
|
||||
}
|
||||
tlsLn := tls.NewListener(ln, tlsConfig)
|
||||
err = p.httpServer.Serve(tlsLn)
|
||||
} else {
|
||||
err = p.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err))
|
||||
serveErr := p.httpServer.Serve(listener)
|
||||
if serveErr != nil && serveErr != http.ErrServerClosed {
|
||||
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyServer) attestedListener(addr string, tlsConfig *tls.Config, identity tls.Certificate) (net.Listener, error) {
|
||||
network := "tcp"
|
||||
address := addr
|
||||
if len(addr) > 0 && addr[0] == '/' {
|
||||
network = unix
|
||||
address = addr
|
||||
_ = os.Remove(address)
|
||||
}
|
||||
|
||||
return atls.Listen(network, address, &atls.ServerConfig{
|
||||
TLSConfig: tlsConfig,
|
||||
Identity: identity,
|
||||
BuildLeafExtensions: p.certProvider.BuildLeafExtensions,
|
||||
})
|
||||
}
|
||||
|
||||
// Stop stops the proxy server.
|
||||
func (p *proxyServer) Stop() error {
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -113,6 +114,9 @@ func TestProxyStartWithoutPort(t *testing.T) {
|
||||
ctx := ProxyContext{ID: "test-2"}
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
if err != nil && strings.Contains(err.Error(), "address already in use") {
|
||||
t.Skip("default ingress port 7002 is already in use")
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -141,6 +145,33 @@ func TestProxyStartAlreadyStarted(t *testing.T) {
|
||||
assert.Equal(t, "proxy server already started", err.Error())
|
||||
}
|
||||
|
||||
func TestProxyStartReturnsListenerError(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil).(*proxyServer)
|
||||
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer occupied.Close()
|
||||
|
||||
port := occupied.Addr().(*net.TCPAddr).Port
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "test-bind-failure"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to listen")
|
||||
assert.False(t, ps.started)
|
||||
|
||||
retry, retryErr := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, retryErr)
|
||||
retryPort := retry.Addr().(*net.TCPAddr).Port
|
||||
retry.Close()
|
||||
|
||||
err = ps.Start(ProxyConfig{Port: fmt.Sprintf("%d", retryPort)}, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
}
|
||||
|
||||
// TestProxyStartAfterStopped tests error when starting after stop.
|
||||
func TestProxyStartAfterStopped(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
@@ -473,5 +504,5 @@ func TestProxyAttestedTLSInvalidCA(t *testing.T) {
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to configure certificate authorities")
|
||||
assert.Contains(t, err.Error(), "failed to setup attested TLS")
|
||||
}
|
||||
|
||||
+5
-17
@@ -4,15 +4,12 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
@@ -125,21 +122,12 @@ func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey strin
|
||||
security = WithMATLS
|
||||
}
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to generate nonce"), err)
|
||||
}
|
||||
|
||||
encoded := hex.EncodeToString(nonce)
|
||||
sni := encoded + ".nonce"
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
RootCAs: rootCAs,
|
||||
ServerName: sni,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return atls.NewCertificateVerifier(rootCAs).VerifyPeerCertificate(rawCerts, verifiedChains, nonce)
|
||||
},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
RootCAs: rootCAs,
|
||||
}
|
||||
if rootCAs == nil {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
if clientCert != "" || clientKey != "" {
|
||||
|
||||
+6
-4
@@ -6,6 +6,7 @@ package tls
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
stdtls "crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
@@ -296,10 +297,11 @@ func TestLoadATLSConfig(t *testing.T) {
|
||||
assert.NotNil(t, result.Config)
|
||||
|
||||
// Verify TLS config properties
|
||||
assert.True(t, result.Config.InsecureSkipVerify)
|
||||
assert.NotNil(t, result.Config.VerifyPeerCertificate)
|
||||
assert.NotEmpty(t, result.Config.ServerName)
|
||||
assert.Contains(t, result.Config.ServerName, ".nonce")
|
||||
assert.Equal(t, tt.serverCAFile == "", result.Config.InsecureSkipVerify)
|
||||
assert.Equal(t, uint16(stdtls.VersionTLS13), result.Config.MinVersion)
|
||||
if tt.serverCAFile != "" {
|
||||
assert.NotNil(t, result.Config.RootCAs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user