mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-474 - New aTLS implementation (#475)
* initial new aTLS * add CA API call for aTLS
This commit is contained in:
committed by
GitHub
parent
9c8ddfd2b1
commit
698bd948ed
@@ -0,0 +1,333 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/certs"
|
||||
certscli "github.com/absmach/certs/cli"
|
||||
"github.com/absmach/certs/errors"
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
const (
|
||||
vmpl2 = 2
|
||||
organization = "Ultraviolet"
|
||||
country = "Serbia"
|
||||
province = ""
|
||||
locality = "Belgrade"
|
||||
streetAddress = "Bulevar Arsenija Carnojevica 103"
|
||||
postalCode = "11000"
|
||||
notAfterYear = 1
|
||||
notAfterMonth = 0
|
||||
notAfterDay = 0
|
||||
)
|
||||
|
||||
var (
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
)
|
||||
|
||||
type csrReq struct {
|
||||
CSR string `json:"csr,omitempty"`
|
||||
}
|
||||
|
||||
func getPlatformProvider(platformType attestation.PlatformType) (attestation.Provider, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return vtpm.NewProvider(true, vmpl2), nil
|
||||
case attestation.Azure:
|
||||
return azure.NewProvider(), nil
|
||||
case attestation.TDX:
|
||||
return tdx.NewProvider(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
|
||||
func getPlatformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
|
||||
var verifier attestation.Verifier
|
||||
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
verifier = vtpm.NewVerifier(nil)
|
||||
case attestation.Azure:
|
||||
verifier = azure.NewVerifier(nil)
|
||||
case attestation.TDX:
|
||||
verifier = tdx.NewVerifier()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
|
||||
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return verifier, nil
|
||||
}
|
||||
|
||||
func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return SNPvTPMOID, nil
|
||||
case attestation.Azure:
|
||||
return AzureOID, nil
|
||||
case attestation.TDX:
|
||||
return TDXOID, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
|
||||
func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
switch {
|
||||
case oid.Equal(SNPvTPMOID):
|
||||
return attestation.SNPvTPM, nil
|
||||
case oid.Equal(AzureOID):
|
||||
return attestation.Azure, nil
|
||||
case oid.Equal(TDXOID):
|
||||
return attestation.TDX, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported OID: %v", oid)
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
verifier, err := getPlatformVerifier(pType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get platform verifier: %w", err)
|
||||
}
|
||||
|
||||
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:vtpm.Nonce]); err != nil {
|
||||
fmt.Printf("failed to verify attestation: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetCertificate(caUrl string, cvmId string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
pType := attestation.CCPlatform()
|
||||
|
||||
provider, err := getPlatformProvider(pType)
|
||||
if err != nil {
|
||||
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, fmt.Errorf("failed to get platform provider: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
teeOid, err := getOID(pType)
|
||||
if err != nil {
|
||||
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, fmt.Errorf("failed to get OID for platform type: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
curve := elliptic.P256()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private/public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
sniLength := len(clientHello.ServerName)
|
||||
if sniLength < 7 || clientHello.ServerName[sniLength-6:] != ".nonce" {
|
||||
return nil, fmt.Errorf("invalid server name: %s", clientHello.ServerName)
|
||||
}
|
||||
|
||||
nonceStr := clientHello.ServerName[:sniLength-6]
|
||||
nonce, err := hex.DecodeString(nonceStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode nonce from server name: %w", err)
|
||||
}
|
||||
|
||||
if len(nonce) != 64 {
|
||||
return nil, fmt.Errorf("invalid nonce length: expected 64 bytes, got %d bytes", len(nonce))
|
||||
}
|
||||
|
||||
attestExtension, err := getCertificateExtension(provider, pubKeyDER, nonce, teeOid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get certificate extension: %w", err)
|
||||
}
|
||||
|
||||
var certDERBytes []byte
|
||||
|
||||
if caUrl == "" && cvmId == "" {
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(202403311),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
ExtraExtensions: []pkix.Extension{attestExtension},
|
||||
}
|
||||
|
||||
DERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
certDERBytes = DERBytes
|
||||
} else {
|
||||
csrmd := certs.CSRMetadata{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
ExtraExtensions: []pkix.Extension{attestExtension},
|
||||
}
|
||||
|
||||
csr, err := certscli.CreateCSR(csrmd, privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CSR: %w", err)
|
||||
}
|
||||
|
||||
csrData := string(csr.CSR)
|
||||
|
||||
r := csrReq{
|
||||
CSR: csrData,
|
||||
}
|
||||
|
||||
data, sdkErr := json.Marshal(r)
|
||||
if sdkErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal CSR request: %w", sdkErr)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay)
|
||||
ttlString := notAfter.Sub(notBefore).String()
|
||||
|
||||
query := url.Values{}
|
||||
query.Add("ttl", ttlString)
|
||||
query_string := query.Encode()
|
||||
|
||||
certsEndpoint := "certs"
|
||||
csrEndpoint := "csrs"
|
||||
endpoint := fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, cvmId)
|
||||
|
||||
url := fmt.Sprintf("%s/%s?%s", caUrl, endpoint, query_string)
|
||||
|
||||
_, body, err := processRequest(http.MethodPost, url, data, nil, http.StatusOK)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process request: %w", err)
|
||||
}
|
||||
|
||||
var cert certssdk.Certificate
|
||||
if err := json.Unmarshal(body, &cert); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
|
||||
}
|
||||
|
||||
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
|
||||
|
||||
block, rest := pem.Decode([]byte(cleanCertificateString))
|
||||
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("failed to convert generated certificate to DER format: %s", cleanCertificateString)
|
||||
}
|
||||
|
||||
certDERBytes = block.Bytes
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDERBytes},
|
||||
PrivateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getCertificateExtension(provider attestation.Provider, pubKey []byte, nonce []byte, teeOid asn1.ObjectIdentifier) (pkix.Extension, error) {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
rawAttestation, err := provider.Attestation(hashNonce[:], hashNonce[:vtpm.Nonce])
|
||||
if err != nil {
|
||||
return pkix.Extension{}, fmt.Errorf("failed to get attestation: %w", err)
|
||||
}
|
||||
|
||||
return pkix.Extension{
|
||||
Id: teeOid,
|
||||
Value: rawAttestation,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func processRequest(method, reqUrl string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
|
||||
req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
// Sets a default value for the Content-Type.
|
||||
// Overridden if Content-Type is passed in the headers arguments.
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
sdkerr := errors.CheckError(resp, expectedRespCodes...)
|
||||
if sdkerr != nil {
|
||||
return make(http.Header), []byte{}, sdkerr
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
return resp.Header, body, nil
|
||||
}
|
||||
@@ -1,453 +0,0 @@
|
||||
// Code generated by cmd/cgo; DO NOT EDIT.
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
// #cgo LDFLAGS: -lssl -lcrypto
|
||||
// #include "extensions.h"
|
||||
// #include <string.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"crypto/sha3"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
const (
|
||||
noError = 0
|
||||
errorZeroReturn = 6
|
||||
errorWantRead = 2
|
||||
errorWantWrite = 3
|
||||
errorSyscall = 5
|
||||
errorSsl = 1
|
||||
waitTime = 2
|
||||
vmpl2 = 2
|
||||
)
|
||||
|
||||
var (
|
||||
errListener = errors.New("listener could not be created")
|
||||
errBadIPFormat = errors.New("bad format of IP address")
|
||||
errCloseTLS = errors.New("could not close TLS connection")
|
||||
errConnFailed = errors.New("tls connection is nil")
|
||||
errWrite = errors.New("could not write to TLS")
|
||||
errTLSConn = errors.New("connection did not close correctly")
|
||||
errReadDeadline = errors.New("could not set read deadline, socket timeout failed")
|
||||
errWriteDeadline = errors.New("could not set write deadline, socket timeout failed")
|
||||
errConnCreate = errors.New("could not create connection")
|
||||
)
|
||||
|
||||
func formTeeData(pubKey []byte, teeNonce []byte) []byte {
|
||||
combined := append(pubKey, teeNonce...)
|
||||
sum := sha3.Sum512(combined)
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
func getPlatformProvider(platformType attestation.PlatformType, pubKey []byte) (attestation.Provider, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return vtpm.NewProvider(pubKey, true, vmpl2), nil
|
||||
case attestation.Azure:
|
||||
return azure.NewProvider(), nil
|
||||
case attestation.TDX:
|
||||
return tdx.NewProvider(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
|
||||
func getPlatformVerifier(platformType attestation.PlatformType, pubKey []byte) (attestation.Verifier, error) {
|
||||
var verifier attestation.Verifier
|
||||
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
verifier = vtpm.NewVerifier(pubKey, nil)
|
||||
case attestation.Azure:
|
||||
verifier = azure.NewVerifier(nil)
|
||||
case attestation.TDX:
|
||||
verifier = tdx.NewVerifier()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
|
||||
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return verifier, nil
|
||||
}
|
||||
|
||||
//export callVerificationValidationCallback
|
||||
func callVerificationValidationCallback(platformType C.int, pubKey *C.uchar, pubKeyLen C.int, attestReport *C.uchar, attestReportSize C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar) C.int {
|
||||
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
|
||||
teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce)
|
||||
vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce)
|
||||
pType := attestation.PlatformType(int(platformType))
|
||||
attestationReport := C.GoBytes(unsafe.Pointer(attestReport), attestReportSize)
|
||||
|
||||
teeData := formTeeData(pubKeyCert, teeNonceData)
|
||||
|
||||
verifier, err := getPlatformVerifier(pType, pubKeyCert)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "no attestation provider found for platform type %s", err.Error())
|
||||
return C.int(-1)
|
||||
}
|
||||
|
||||
err = verifier.VerifyAttestation(attestationReport, teeData, vTPMNonce)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "verification callback failed %s", err.Error())
|
||||
return C.int(-1)
|
||||
}
|
||||
|
||||
return C.int(0)
|
||||
}
|
||||
|
||||
//export callFetchAttestationCallback
|
||||
func callFetchAttestationCallback(platformType C.int, pubKey *C.uchar, pubKeyLen C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar, outlen *C.ulong) *C.uchar {
|
||||
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
|
||||
teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce)
|
||||
vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce)
|
||||
pType := attestation.PlatformType(int(platformType))
|
||||
|
||||
teeData := formTeeData(pubKeyCert, teeNonceData)
|
||||
|
||||
provider, err := getPlatformProvider(pType, pubKeyCert)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "no attestation provider found for platform type %s", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
quote, err := provider.Attestation(teeData, vTPMNonce)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "attestation callback returned nil: %s", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
*outlen = C.ulong(len(quote))
|
||||
resultC := C.malloc(C.size_t(len(quote)))
|
||||
if resultC == nil {
|
||||
fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback")
|
||||
return nil
|
||||
}
|
||||
|
||||
C.memcpy(resultC, unsafe.Pointer("e[0]), C.size_t(len(quote)))
|
||||
|
||||
return (*C.uchar)(resultC)
|
||||
}
|
||||
|
||||
//export returnCCPlatformType
|
||||
func returnCCPlatformType() int32 {
|
||||
return int32(attestation.CCPlatform())
|
||||
}
|
||||
|
||||
type ATLSServerListener struct {
|
||||
tlsListener *C.tls_server_connection
|
||||
}
|
||||
|
||||
func Listen(addr string, cert []byte, key []byte) (net.Listener, error) {
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errListener, err)
|
||||
}
|
||||
|
||||
p, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errBadIPFormat, err)
|
||||
}
|
||||
|
||||
cCertPEM := (*C.char)(unsafe.Pointer(&cert[0]))
|
||||
cKeyPEM := (*C.char)(unsafe.Pointer(&key[0]))
|
||||
cIP := C.CString(ip)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
atlsListener := C.start_tls_server(
|
||||
cCertPEM, C.int(len(cert)),
|
||||
cKeyPEM, C.int(len(key)),
|
||||
cIP, C.int(p))
|
||||
if atlsListener == nil {
|
||||
return nil, errors.Wrap(errListener, err)
|
||||
}
|
||||
|
||||
return &ATLSServerListener{tlsListener: atlsListener}, nil
|
||||
}
|
||||
|
||||
// accept implements the Accept method in the [Listener] interface; it
|
||||
// waits for the next call and returns a generic [Conn].
|
||||
func (l *ATLSServerListener) Accept() (net.Conn, error) {
|
||||
conn := C.tls_server_accept(l.tlsListener)
|
||||
if conn == nil {
|
||||
return &ATLSConn{tlsConn: nil}, nil
|
||||
}
|
||||
|
||||
return &ATLSConn{tlsConn: conn}, nil
|
||||
}
|
||||
|
||||
// close stops listening on the TCP address.
|
||||
// already Accepted connections are not closed.
|
||||
func (l *ATLSServerListener) Close() error {
|
||||
ret := C.tls_server_close(l.tlsListener)
|
||||
if ret != 0 {
|
||||
return errCloseTLS
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addr returns the listener's network address, a [*TCPAddr].
|
||||
// the Addr returned is shared by all invocations of Addr, so
|
||||
// do not modify it.
|
||||
func (l *ATLSServerListener) Addr() net.Addr {
|
||||
cIP := C.tls_return_addr(&l.tlsListener.addr)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoString(cIP)
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&l.tlsListener.addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
type ATLSConn struct {
|
||||
tlsConn *C.tls_connection
|
||||
fdReadMutex sync.Mutex
|
||||
fdWriteMutex sync.Mutex
|
||||
fdDelayMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Read(b []byte) (int, error) {
|
||||
c.fdReadMutex.Lock()
|
||||
defer c.fdReadMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return 0, errConnFailed
|
||||
}
|
||||
|
||||
n := int(C.tls_read(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
|
||||
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// call the C function SSL_get_error to interpret the error.
|
||||
errCode := int(C.SSL_get_error(c.tlsConn.ssl, C.int(n)))
|
||||
|
||||
// handle specific error codes returned by SSL_get_error.
|
||||
switch errCode {
|
||||
case noError:
|
||||
return n, nil // no error.
|
||||
case errorZeroReturn:
|
||||
fmt.Fprintf(os.Stdout, "Connection closed by peer\n")
|
||||
return 0, io.EOF // connection closed.
|
||||
case errorWantRead:
|
||||
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later\n")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case errorWantWrite:
|
||||
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later\n")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case errorSyscall:
|
||||
fmt.Fprintf(os.Stderr, "I/O error\n")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
case errorSsl:
|
||||
fmt.Fprintf(os.Stderr, "I/O error\n")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode)
|
||||
return 0, fmt.Errorf("SSL error")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Write(b []byte) (int, error) {
|
||||
c.fdWriteMutex.Lock()
|
||||
defer c.fdWriteMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return 0, errConnFailed
|
||||
}
|
||||
|
||||
n := int(C.tls_write(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
|
||||
if n < 0 {
|
||||
return 0, errWrite
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Close() error {
|
||||
c.fdReadMutex.Lock()
|
||||
defer c.fdReadMutex.Unlock()
|
||||
|
||||
c.fdWriteMutex.Lock()
|
||||
defer c.fdWriteMutex.Unlock()
|
||||
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
ret := C.tls_close(c.tlsConn)
|
||||
|
||||
if int(ret) == 0 {
|
||||
c.fdDelayMutex.Unlock()
|
||||
c.fdWriteMutex.Unlock()
|
||||
c.fdReadMutex.Unlock()
|
||||
time.Sleep(waitTime * time.Millisecond)
|
||||
c.fdDelayMutex.Lock()
|
||||
c.fdWriteMutex.Lock()
|
||||
c.fdReadMutex.Lock()
|
||||
} else if int(ret) < 0 {
|
||||
c.tlsConn = nil
|
||||
return errTLSConn
|
||||
} else if int(ret) == 1 {
|
||||
c.tlsConn = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) LocalAddr() net.Addr {
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
cIP := C.tls_return_addr(&c.tlsConn.local_addr)
|
||||
ipLength := C.strlen(cIP)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoStringN(cIP, C.int(ipLength))
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
fmt.Println("Invalid IP address")
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&c.tlsConn.local_addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) RemoteAddr() net.Addr {
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
cIP := C.tls_return_addr(&c.tlsConn.remote_addr)
|
||||
if cIP == nil {
|
||||
fmt.Println("RemoteAddr error while fetching ip")
|
||||
return nil
|
||||
}
|
||||
|
||||
ipLength := C.strlen(cIP)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoStringN(cIP, C.int(ipLength))
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
fmt.Println("Invalid IP address")
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&c.tlsConn.remote_addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errReadDeadline
|
||||
}
|
||||
|
||||
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errWriteDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetReadDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errReadDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetWriteDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errWriteDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DialTLSClient(hostname string, port int) (net.Conn, error) {
|
||||
cHostName := C.CString(hostname)
|
||||
defer C.free(unsafe.Pointer(cHostName))
|
||||
|
||||
conn := C.new_tls_connection(cHostName, C.int(port))
|
||||
if conn == nil {
|
||||
return nil, errConnCreate
|
||||
}
|
||||
|
||||
return &ATLSConn{tlsConn: conn}, nil
|
||||
}
|
||||
|
||||
func timeToTimeout(t time.Time) (int, int) {
|
||||
if t.IsZero() {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
d := time.Until(t)
|
||||
if d <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seconds := int(d.Seconds())
|
||||
microseconds := int(d.Nanoseconds()/1000) % 1_000_000
|
||||
return seconds, microseconds
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
cert := []byte("dummy_cert")
|
||||
key := []byte("dummy_key")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
address string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid address",
|
||||
address: "127.0.0.1:8889",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid address format",
|
||||
address: "127.0.0.1",
|
||||
err: errListener,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
l, err := Listen(c.address, cert, key)
|
||||
assert.True(t, errors.Contains(err, c.err))
|
||||
if l != nil {
|
||||
t.Cleanup(func() {
|
||||
err := l.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestATLSServerListener_Accept(t *testing.T) {
|
||||
t.Run("Accepts connection", func(t *testing.T) {
|
||||
listener, err := Listen("127.0.0.1:8887", []byte("dummy_cert"), []byte("dummy_key"))
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
conn, err := listener.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_Read(t *testing.T) {
|
||||
buffer := make([]byte, 1024)
|
||||
|
||||
t.Run("Read with nil connection", func(t *testing.T) {
|
||||
conn := &ATLSConn{tlsConn: nil}
|
||||
_, err := conn.Read(buffer)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err, errConnFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_Write(t *testing.T) {
|
||||
data := []byte("test data")
|
||||
|
||||
t.Run("Write with nil connection", func(t *testing.T) {
|
||||
conn := &ATLSConn{tlsConn: nil}
|
||||
_, err := conn.Write(data)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err, errConnFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_DeadlineFunctions(t *testing.T) {
|
||||
conn := &ATLSConn{}
|
||||
|
||||
t.Run("SetDeadline - valid time", func(t *testing.T) {
|
||||
err := conn.SetDeadline(time.Now().Add(1 * time.Minute))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetReadDeadline - past time", func(t *testing.T) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(-1 * time.Minute))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetWriteDeadline - zero time", func(t *testing.T) {
|
||||
err := conn.SetWriteDeadline(time.Time{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const sevProductNameMilan = "Milan"
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
func TestGetPlatformProvider(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
platformType attestation.PlatformType
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid platform type SNPvTPM",
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type Azure",
|
||||
platformType: attestation.Azure,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type TDX",
|
||||
platformType: attestation.TDX,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid platform type",
|
||||
platformType: 999,
|
||||
expectedError: errors.New("unsupported platform type: 999"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
provider, err := getPlatformProvider(c.platformType)
|
||||
|
||||
if c.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, c.expectedError.Error(), err.Error())
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPlatformVerifier(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
attestationPB := prepVerifyAttReport(t)
|
||||
err = setAttestationPolicy(attestationPB, tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
platformType attestation.PlatformType
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid platform type SNPvTPM",
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type Azure",
|
||||
platformType: attestation.Azure,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type TDX",
|
||||
platformType: attestation.TDX,
|
||||
expectedError: errors.New("unknown field \"pcr_values\""),
|
||||
},
|
||||
{
|
||||
name: "Invalid platform type",
|
||||
platformType: 999,
|
||||
expectedError: errors.New("unsupported platform type: 999"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
verifier, err := getPlatformVerifier(c.platformType)
|
||||
|
||||
if c.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), c.expectedError.Error())
|
||||
assert.Nil(t, verifier)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, verifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOID(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
platformType attestation.PlatformType
|
||||
expectedOID asn1.ObjectIdentifier
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid platform type SNPvTPM",
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectedOID: SNPvTPMOID,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type Azure",
|
||||
platformType: attestation.Azure,
|
||||
expectedOID: AzureOID,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid platform type TDX",
|
||||
platformType: attestation.TDX,
|
||||
expectedOID: TDXOID,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid platform type",
|
||||
platformType: 999,
|
||||
expectedOID: nil,
|
||||
expectedError: errors.New("unsupported platform type: 999"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
oid, err := getOID(c.platformType)
|
||||
|
||||
if c.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, c.expectedError.Error(), err.Error())
|
||||
assert.Nil(t, oid)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.expectedOID, oid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPlatformTypeFromOID(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
oid asn1.ObjectIdentifier
|
||||
expectedType attestation.PlatformType
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid OID for SNPvTPM",
|
||||
oid: SNPvTPMOID,
|
||||
expectedType: attestation.SNPvTPM,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid OID for Azure",
|
||||
oid: AzureOID,
|
||||
expectedType: attestation.Azure,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid OID for TDX",
|
||||
oid: TDXOID,
|
||||
expectedType: attestation.TDX,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid OID",
|
||||
oid: asn1.ObjectIdentifier{1, 2, 3},
|
||||
expectedType: 0,
|
||||
expectedError: errors.New("unsupported OID: 1.2.3"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
pType, err := GetPlatformTypeFromOID(c.oid)
|
||||
|
||||
if c.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, c.expectedError.Error(), err.Error())
|
||||
assert.Equal(t, attestation.PlatformType(0), pType)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.expectedType, pType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
|
||||
file, err := os.ReadFile("../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(file) < abi.ReportSize {
|
||||
file = append(file, make([]byte, abi.ReportSize-len(file))...)
|
||||
}
|
||||
|
||||
rr, err := abi.ReportCertsToProto(file)
|
||||
require.NoError(t, err)
|
||||
|
||||
return rr
|
||||
}
|
||||
|
||||
func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error {
|
||||
attestationPolicyFile, err := os.ReadFile("../../scripts/attestation_policy/attestation_policy.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unmarshalOptions := protojson.UnmarshalOptions{DiscardUnknown: true}
|
||||
|
||||
err = unmarshalOptions.Unmarshal(attestationPolicyFile, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy.Config.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
|
||||
policy.Config.Policy.FamilyId = rr.Report.FamilyId
|
||||
policy.Config.Policy.ImageId = rr.Report.ImageId
|
||||
policy.Config.Policy.Measurement = rr.Report.Measurement
|
||||
policy.Config.Policy.HostData = rr.Report.HostData
|
||||
policy.Config.Policy.ReportIdMa = rr.Report.ReportIdMa
|
||||
policy.Config.RootOfTrust.ProductLine = sevProductNameMilan
|
||||
|
||||
policyByte, err := vtpm.ConvertPolicyToJSON(&policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policyPath := filepath.Join(policyDirectory, "attestation_policy.json")
|
||||
|
||||
err = os.WriteFile(policyPath, policyByte, 0o644)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
attestation.AttestationPolicyPath = policyPath
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,322 +0,0 @@
|
||||
#include "extensions.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <openssl/sha.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
extern int callVerificationValidationCallback(int platformType, const u_char* pubKey, int pubKeyLen, const u_char* quote, int quoteSize, const u_char* teeNonceByte, const u_char* vTPMNonceByte);
|
||||
extern u_char* callFetchAttestationCallback(int platformType, const u_char* pubKey, int pubKeyLen, const u_char* teeNonceByte, const u_char* vTPMNonceByte, unsigned long* outlen);
|
||||
extern uintptr_t validationVerificationCallback(int teeType);
|
||||
extern uintptr_t getPlatformTypeHandle(int platformType, u_char *teeNonce, u_char *vtpmNonce);
|
||||
extern int returnCCPlatformType();
|
||||
|
||||
int triggerVerificationValidationCallback(int platformType, u_char* pub_key, int pub_key_len, u_char *quote, int quote_size, u_char *tee_nonce, u_char *vtpm_nonce) {
|
||||
if (quote == NULL || vtpm_nonce == NULL || tee_nonce == NULL || pub_key == NULL) {
|
||||
fprintf(stderr, "attestation and noce and public key cannot be NULL\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return callVerificationValidationCallback(platformType, pub_key, pub_key_len, quote, quote_size, tee_nonce, vtpm_nonce);
|
||||
}
|
||||
|
||||
u_char* triggerFetchAttestationCallback(int platformType, u_char* pub_key, int pub_key_len, char *tee_nonce, char *vtpm_nonce, unsigned long *outlen) {
|
||||
if(tee_nonce == NULL || vtpm_nonce == NULL) {
|
||||
fprintf(stderr, "Report data cannot be NULL");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return callFetchAttestationCallback(platformType, pub_key, pub_key_len, tee_nonce, vtpm_nonce, outlen);
|
||||
}
|
||||
|
||||
/*
|
||||
Evidence request extension
|
||||
- Contains a random nonce that goes into the attestation report
|
||||
- Is sent in the ClientHello message
|
||||
*/
|
||||
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *out,
|
||||
void *add_arg)
|
||||
{
|
||||
free((void *)out);
|
||||
}
|
||||
|
||||
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
evidence_request *er = (evidence_request*)malloc(sizeof(evidence_request));
|
||||
|
||||
if (er == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (ext_data != NULL) {
|
||||
if (RAND_bytes(ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE) != 1) {
|
||||
perror("could not generate random bytes for vtpm nonce, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE);
|
||||
}
|
||||
|
||||
if (RAND_bytes(ext_data->er.tee_nonce, REPORT_DATA_SIZE) != 1) {
|
||||
perror("could not generate random bytes for tee nonce, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
free(er);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(er->vtpm_nonce, ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE);
|
||||
memcpy(er->tee_nonce, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
|
||||
|
||||
*out = (const u_char *)er;
|
||||
*outlen = sizeof(evidence_request);
|
||||
return 1;
|
||||
}
|
||||
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
int32_t *platform_type = (int32_t*)malloc(sizeof(int32_t));
|
||||
|
||||
if (platform_type == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
*platform_type = returnCCPlatformType();
|
||||
ext_data->platform_type = *platform_type;
|
||||
|
||||
*out = (u_char*)platform_type;
|
||||
*outlen = sizeof(int32_t);
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "bad context\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
evidence_request *er = (evidence_request*)in;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
memcpy(ext_data->er.vtpm_nonce, er->vtpm_nonce, NONCE_RANDOM_SIZE);
|
||||
memcpy(ext_data->er.tee_nonce, er->tee_nonce, REPORT_DATA_SIZE);
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
|
||||
{
|
||||
int *platform_type = (int*)in;
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
ext_data->platform_type = *platform_type;
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Attestation Certificate extension
|
||||
- Contains the attestation report
|
||||
- The attestation report contains the hash of the nonce, the Public Key of the x.509 Agent certificate, and the vTPM AK
|
||||
*/
|
||||
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *out,
|
||||
void *add_arg)
|
||||
{
|
||||
free((void *)out);
|
||||
}
|
||||
|
||||
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
return 1;
|
||||
case SSL_EXT_TLS1_3_CERTIFICATE:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
if (ext_data != NULL) {
|
||||
u_char *quote;
|
||||
size_t len = 0;
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
int pubkey_len = 0;
|
||||
|
||||
|
||||
if (x != NULL) {
|
||||
pkey = X509_get_pubkey(x);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "agent certificate must be used for aTLS\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
quote = triggerFetchAttestationCallback(ext_data->platform_type, pubkey_buf, pubkey_len, ext_data->er.tee_nonce, ext_data->er.vtpm_nonce, &len);
|
||||
if (quote == NULL) {
|
||||
fprintf(stderr, "attestation report is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
return -1;
|
||||
}
|
||||
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
|
||||
*out = quote;
|
||||
*outlen = len;
|
||||
return 1;
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
return 1;
|
||||
case SSL_EXT_TLS1_3_CERTIFICATE:
|
||||
{
|
||||
if (x != NULL) {
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
char *quote = (char*)malloc(inlen*sizeof(char));
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
int pubkey_len = 0;
|
||||
int res = 0;
|
||||
|
||||
if (quote == NULL) {
|
||||
perror("could not allocate memory");
|
||||
return 0;
|
||||
}
|
||||
|
||||
pkey = X509_get_pubkey(x);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
memcpy(quote, in, inlen);
|
||||
|
||||
res = triggerVerificationValidationCallback(ext_data->platform_type,
|
||||
pubkey_buf,
|
||||
pubkey_len,
|
||||
quote,
|
||||
inlen,
|
||||
(u_char*)&ext_data->er.tee_nonce,
|
||||
(u_char*)&ext_data->er.vtpm_nonce);
|
||||
free(quote);
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "verification and validation failed, aborting connection\n");
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
} else {
|
||||
fprintf(stderr, "agent certificates must be used for aTLS\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
#ifndef ATLS_EXTENSION_H
|
||||
#define ATLS_EXTENSION_H
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
#include <arpa/inet.h>
|
||||
|
||||
#define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65
|
||||
#define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66
|
||||
#define REPORT_DATA_SIZE 64
|
||||
#define NONCE_RANDOM_SIZE 32
|
||||
#define TLS_CLIENT_CTX 0
|
||||
#define TLS_SERVER_CTX 1
|
||||
|
||||
typedef struct evidence_request
|
||||
{
|
||||
char vtpm_nonce[NONCE_RANDOM_SIZE];
|
||||
char tee_nonce[REPORT_DATA_SIZE];
|
||||
} evidence_request;
|
||||
|
||||
typedef struct tls_extension_data
|
||||
{
|
||||
int platform_type;
|
||||
evidence_request er;
|
||||
} tls_extension_data;
|
||||
|
||||
typedef struct tls_server_connection
|
||||
{
|
||||
int server_fd;
|
||||
char* cert;
|
||||
int cert_len;
|
||||
char* key;
|
||||
int key_len;
|
||||
struct sockaddr_storage addr;
|
||||
} tls_server_connection;
|
||||
|
||||
typedef struct tls_connection
|
||||
{
|
||||
SSL_CTX *ctx;
|
||||
SSL *ssl;
|
||||
int socket_fd;
|
||||
struct sockaddr_storage local_addr;
|
||||
struct sockaddr_storage remote_addr;
|
||||
tls_extension_data tls_ext_data;
|
||||
} tls_connection;
|
||||
|
||||
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port);
|
||||
tls_connection* tls_server_accept(tls_server_connection *tls_server);
|
||||
int tls_server_close(tls_server_connection *tls_server);
|
||||
int tls_read(tls_connection *conn, void *buf, int num);
|
||||
int tls_write(tls_connection *conn, const void *buf, int num);
|
||||
int tls_close(tls_connection *conn);
|
||||
tls_connection* new_tls_connection(char *address, int port);
|
||||
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
|
||||
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
|
||||
char* tls_return_addr(struct sockaddr_storage *addr);
|
||||
int tls_return_port(struct sockaddr_storage *addr);
|
||||
|
||||
// Extensions
|
||||
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *out,
|
||||
void *add_arg);
|
||||
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg);
|
||||
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg);
|
||||
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *out,
|
||||
void *add_arg);
|
||||
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg);
|
||||
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg);
|
||||
|
||||
#endif // ATLS_EXTENSION_H
|
||||
@@ -1,607 +0,0 @@
|
||||
#include "extensions.h"
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <stdio.h>
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <string.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/time.h>
|
||||
#include <ifaddrs.h>
|
||||
|
||||
void init_openssl() {
|
||||
SSL_load_error_strings();
|
||||
OpenSSL_add_ssl_algorithms();
|
||||
}
|
||||
|
||||
void cleanup_openssl() {
|
||||
EVP_cleanup();
|
||||
}
|
||||
|
||||
int load_certificates_from_memory(SSL_CTX* ctx, const char* cert, int cert_len, const char* key, int key_len) {
|
||||
BIO* cert_bio = NULL;
|
||||
BIO* key_bio = NULL;
|
||||
X509* x509_cert = NULL;
|
||||
EVP_PKEY* pkey = NULL;
|
||||
int success = 0;
|
||||
|
||||
cert_bio = BIO_new_mem_buf(cert, cert_len);
|
||||
if (cert_bio == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
key_bio = BIO_new_mem_buf(key, key_len);
|
||||
if (key_bio == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
x509_cert = PEM_read_bio_X509(cert_bio, NULL, 0, NULL);
|
||||
if (x509_cert == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
pkey = PEM_read_bio_PrivateKey(key_bio, NULL, 0, NULL);
|
||||
if (pkey == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_certificate(ctx, x509_cert) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_check_private_key(ctx) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
success = 1;
|
||||
|
||||
cleanup:
|
||||
if (cert_bio) BIO_free(cert_bio);
|
||||
if (key_bio) BIO_free(key_bio);
|
||||
if (x509_cert) X509_free(x509_cert);
|
||||
if (pkey) EVP_PKEY_free(pkey);
|
||||
|
||||
if (!success) {
|
||||
ERR_print_errors_fp(stderr);
|
||||
}
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
int enforce_tls1_3_only(SSL_CTX *ctx) {
|
||||
if (SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION) == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
SSL_CTX *create_context(int is_server) {
|
||||
const SSL_METHOD *method;
|
||||
SSL_CTX *ctx;
|
||||
|
||||
if (is_server) {
|
||||
method = TLS_server_method();
|
||||
} else {
|
||||
method = TLS_client_method();
|
||||
}
|
||||
|
||||
ctx = SSL_CTX_new(method);
|
||||
if (!ctx) {
|
||||
perror("Unable to create SSL context");
|
||||
ERR_print_errors_fp(stderr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!enforce_tls1_3_only(ctx)) {
|
||||
fprintf(stderr, "could not enforce TLS1.3\n");
|
||||
SSL_CTX_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
int add_custom_tls_extension(SSL_CTX *ctx, tls_connection *conn) {
|
||||
uint32_t flags_nonce = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS;
|
||||
uint32_t flags_attestation = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_CERTIFICATE;
|
||||
int ret = 1;
|
||||
void *data = NULL;
|
||||
|
||||
if (conn != NULL) {
|
||||
data = (void*)&conn->tls_ext_data;
|
||||
}
|
||||
|
||||
ret = SSL_CTX_add_custom_ext(ctx,
|
||||
EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE,
|
||||
flags_nonce,
|
||||
evidence_request_ext_add_cb,
|
||||
evidence_request_ext_free_cb,
|
||||
data,
|
||||
evidence_request_ext_parse_cb,
|
||||
data);
|
||||
if (ret != 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
ret = SSL_CTX_add_custom_ext(ctx,
|
||||
ATTESTATION_CERTIFICATE_EXTENSION_TYPE,
|
||||
flags_attestation,
|
||||
attestation_certificate_ext_add_cb,
|
||||
attestation_certificate_ext_free_cb,
|
||||
data,
|
||||
attestation_certificate_ext_parse_cb,
|
||||
data);
|
||||
if (ret != 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Function to start the TLS server
|
||||
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port) {
|
||||
tls_server_connection *tls_server = (tls_server_connection*)malloc(sizeof(tls_server_connection));
|
||||
int opt = 0;
|
||||
|
||||
if (tls_server == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
init_openssl();
|
||||
|
||||
tls_server->cert = (char*)malloc(cert_len * sizeof(char));
|
||||
if (tls_server->cert == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
goto cleanup_tls_server;
|
||||
}
|
||||
|
||||
tls_server->key = (char*)malloc(key_len * sizeof(char));
|
||||
if (tls_server->key == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
goto cleanup_cert;
|
||||
}
|
||||
|
||||
memcpy(tls_server->cert, cert, cert_len);
|
||||
memcpy(tls_server->key, key, key_len);
|
||||
tls_server->cert_len = cert_len;
|
||||
tls_server->key_len = key_len;
|
||||
|
||||
tls_server->server_fd = socket(AF_INET6, SOCK_STREAM, 0);
|
||||
if (tls_server->server_fd < 0) {
|
||||
perror("Unable to create socket");
|
||||
goto cleanup_key;
|
||||
}
|
||||
|
||||
// Enable both IPv4-mapped and IPv6 addresses
|
||||
if (setsockopt(tls_server->server_fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) != 0) {
|
||||
perror("setsockopt(IPV6_V6ONLY) failed");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
// Configure address structure
|
||||
memset(&(tls_server->addr), 0, sizeof(tls_server->addr));
|
||||
struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&tls_server->addr;
|
||||
addr6->sin6_family = AF_INET6;
|
||||
addr6->sin6_port = htons(port);
|
||||
|
||||
// Set the appropriate address (IPv4-mapped if needed)
|
||||
if (ip == NULL || (strlen(ip) + 1) < INET_ADDRSTRLEN) {
|
||||
addr6->sin6_addr = in6addr_any;
|
||||
} else if (strchr(ip, ':') != NULL) {
|
||||
if (inet_pton(AF_INET6, ip, &(addr6->sin6_addr)) <= 0) {
|
||||
perror("Invalid IPv6 address");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
} else {
|
||||
struct in_addr ipv4_addr;
|
||||
if (inet_pton(AF_INET, ip, &ipv4_addr) <= 0) {
|
||||
perror("Invalid IPv4 address");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
memset(&addr6->sin6_addr, 0, sizeof(addr6->sin6_addr));
|
||||
addr6->sin6_addr.s6_addr[10] = 0xff;
|
||||
addr6->sin6_addr.s6_addr[11] = 0xff;
|
||||
memcpy(&addr6->sin6_addr.s6_addr[12], &ipv4_addr, sizeof(ipv4_addr));
|
||||
}
|
||||
|
||||
if (bind(tls_server->server_fd, (struct sockaddr*)&(tls_server->addr), sizeof(tls_server->addr)) < 0) {
|
||||
perror("Unable to bind");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
if (listen(tls_server->server_fd, SOMAXCONN) < 0) {
|
||||
perror("Unable to listen");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
printf("Listening on port: %d\n", port);
|
||||
return tls_server;
|
||||
|
||||
// Cleanup labels
|
||||
cleanup_socket:
|
||||
close(tls_server->server_fd);
|
||||
cleanup_key:
|
||||
free(tls_server->key);
|
||||
cleanup_cert:
|
||||
free(tls_server->cert);
|
||||
cleanup_tls_server:
|
||||
free(tls_server);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Function to accept a client connection
|
||||
tls_connection* tls_server_accept(tls_server_connection *tls_server) {
|
||||
uint32_t len = sizeof(struct sockaddr_storage);
|
||||
tls_connection *conn = (tls_connection*)malloc(sizeof(tls_connection));
|
||||
int client_fd = -1;
|
||||
int ret = 0;
|
||||
|
||||
if (conn == NULL) {
|
||||
perror("Unable to allocate memory for tls_connection");
|
||||
return NULL;
|
||||
}
|
||||
conn->ctx = NULL;
|
||||
conn->ssl = NULL;
|
||||
conn->socket_fd = -1;
|
||||
|
||||
// Initialize the context
|
||||
conn->ctx = create_context(TLS_SERVER_CTX);
|
||||
if (conn->ctx == NULL) {
|
||||
perror("Unable to create context");
|
||||
goto cleanup_conn;
|
||||
}
|
||||
|
||||
// Load certificates
|
||||
if (!load_certificates_from_memory(conn->ctx, tls_server->cert, tls_server->cert_len, tls_server->key, tls_server->key_len)) {
|
||||
fprintf(stderr, "Failed to load certificates\n");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Add custom TLS extension
|
||||
ret = add_custom_tls_extension(conn->ctx, conn);
|
||||
if (!ret) {
|
||||
perror("Unable to add custom tls extensions");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Accept client connection
|
||||
client_fd = accept(tls_server->server_fd, NULL, NULL);
|
||||
if (client_fd < 0) {
|
||||
perror("Unable to accept connection");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Create SSL object
|
||||
conn->ssl = SSL_new(conn->ctx);
|
||||
if (conn->ssl == NULL) {
|
||||
perror("Unable to create SSL object");
|
||||
goto cleanup_fd;
|
||||
}
|
||||
|
||||
// Set file descriptor
|
||||
conn->socket_fd = client_fd;
|
||||
SSL_set_fd(conn->ssl, client_fd);
|
||||
|
||||
// Get local address
|
||||
if (getsockname(client_fd, (struct sockaddr *)&conn->local_addr, &len) == -1) {
|
||||
perror("getsockname failed during TLS server accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
// Get remote address
|
||||
if (getpeername(client_fd, (struct sockaddr *)&conn->remote_addr, &len) == -1) {
|
||||
perror("getpeername failed during TLS server accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
// Perform SSL handshake
|
||||
ret = SSL_accept(conn->ssl);
|
||||
if (ret <= 0) {
|
||||
perror("SSL handshake failed during accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
return conn;
|
||||
|
||||
cleanup_ssl:
|
||||
if (conn->ssl) SSL_free(conn->ssl);
|
||||
cleanup_fd:
|
||||
if (client_fd >= 0) close(client_fd);
|
||||
cleanup_ctx:
|
||||
if (conn->ctx) SSL_CTX_free(conn->ctx);
|
||||
cleanup_conn:
|
||||
free(conn);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Function to close the server
|
||||
int tls_server_close(tls_server_connection *tls_server) {
|
||||
close(tls_server->server_fd);
|
||||
cleanup_openssl();
|
||||
free(tls_server->cert);
|
||||
free(tls_server->key);
|
||||
free(tls_server);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tls_read(tls_connection *conn, void *buf, int num) {
|
||||
return SSL_read(conn->ssl, buf, num);
|
||||
}
|
||||
|
||||
int tls_write(tls_connection *conn, const void *buf, int num) {
|
||||
if (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return SSL_write(conn->ssl, buf, num);
|
||||
}
|
||||
|
||||
int tls_close(tls_connection *conn) {
|
||||
if (conn != NULL) {
|
||||
if (conn->ssl != NULL) {
|
||||
int ret = 0;
|
||||
|
||||
if (SSL_has_pending(conn->ssl) == 1 || (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN)) {
|
||||
int num = SSL_pending(conn->ssl);
|
||||
char c[num];
|
||||
int res = 0;
|
||||
int end = 0;
|
||||
|
||||
res = SSL_read(conn->ssl, (void*)c, num);
|
||||
res = SSL_get_error(conn->ssl, res);
|
||||
|
||||
if (res == SSL_ERROR_ZERO_RETURN) {
|
||||
end = 1;
|
||||
} else if (res != SSL_ERROR_NONE) {
|
||||
fprintf(stderr, "SSL_read failed in TLS close call\n");
|
||||
end = 1;
|
||||
}
|
||||
|
||||
if ((SSL_get_shutdown(conn->ssl) & SSL_RECEIVED_SHUTDOWN) || end == 1) {
|
||||
ret = SSL_shutdown(conn->ssl);
|
||||
}
|
||||
} else {
|
||||
ret = SSL_shutdown(conn->ssl);
|
||||
}
|
||||
|
||||
if (ret < 0) {
|
||||
ret = SSL_get_error(conn->ssl, ret);
|
||||
fprintf(stderr, "SSL did not shutdown correctly, error code: %d\n", ret);
|
||||
free(conn);
|
||||
close(conn->socket_fd);
|
||||
conn = NULL;
|
||||
return -1;
|
||||
} else if (ret == 0) {
|
||||
return 0;
|
||||
}
|
||||
conn->ssl = NULL;
|
||||
}
|
||||
if (conn->socket_fd >= 0) {
|
||||
close(conn->socket_fd);
|
||||
conn->socket_fd = -1;
|
||||
}
|
||||
SSL_free(conn->ssl);
|
||||
if (conn->ctx != NULL) {
|
||||
SSL_CTX_free(conn->ctx);
|
||||
conn->ctx = NULL;
|
||||
}
|
||||
|
||||
free(conn);
|
||||
conn = NULL;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
char* tls_return_addr(struct sockaddr_storage *addr) {
|
||||
socklen_t addr_len = sizeof(struct sockaddr_storage);
|
||||
int inet_size =addr->ss_family == AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN;
|
||||
char *ip_str = (char*)malloc(inet_size*sizeof(char));
|
||||
void * addr_ptr;
|
||||
|
||||
if (ip_str == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (addr->ss_family == AF_INET) {
|
||||
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
|
||||
addr_ptr = &(ipv4->sin_addr);
|
||||
} else if (addr->ss_family == AF_INET6) {
|
||||
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
|
||||
addr_ptr = &(ipv6->sin6_addr);
|
||||
} else {
|
||||
fprintf(stderr, "unknown family: %d\n", addr->ss_family);
|
||||
free(ip_str);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (inet_ntop(addr->ss_family, addr_ptr, ip_str, inet_size) == NULL) {
|
||||
perror("inet_ntop failed");
|
||||
free(ip_str);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ip_str;
|
||||
}
|
||||
|
||||
int tls_return_port(struct sockaddr_storage *addr) {
|
||||
if (addr->ss_family == AF_INET) {
|
||||
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
|
||||
return ntohs(ipv4->sin_port);
|
||||
} else if (addr->ss_family == AF_INET6) {
|
||||
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
|
||||
return ntohs(ipv6->sin6_port);
|
||||
}
|
||||
|
||||
fprintf(stderr, "cannot return port from unknown family: %d\n", addr->ss_family);
|
||||
return -1;
|
||||
}
|
||||
|
||||
tls_connection* new_tls_connection(char *address, int port) {
|
||||
SSL_CTX *ctx = NULL;
|
||||
SSL *ssl = NULL;
|
||||
int socket_fd = -1;
|
||||
int status;
|
||||
struct addrinfo hints, *res = NULL, *p = NULL;
|
||||
char port_str[6];
|
||||
tls_connection *conn = NULL;
|
||||
socklen_t addr_len;
|
||||
int ret = 0;
|
||||
|
||||
conn = (tls_connection*)malloc(sizeof(tls_connection));
|
||||
if (!conn) {
|
||||
perror("Failed to allocate memory for atls connection");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Format the port string
|
||||
snprintf(port_str, sizeof(port_str), "%d", port);
|
||||
|
||||
// Initialize OpenSSL
|
||||
init_openssl();
|
||||
|
||||
// Create SSL context
|
||||
ctx = create_context(TLS_CLIENT_CTX);
|
||||
if (!ctx) {
|
||||
perror("Could not create context");
|
||||
goto cleanup_conn;
|
||||
}
|
||||
|
||||
conn->ctx = ctx;
|
||||
// Add custom TLS extension
|
||||
ret = add_custom_tls_extension(conn->ctx, conn);
|
||||
if (!ret) {
|
||||
perror("Unable to add custom tls extensions");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Prepare the address info hints
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
hints.ai_flags = 0;
|
||||
hints.ai_protocol = 0;
|
||||
|
||||
// Get address info
|
||||
status = getaddrinfo(address, port_str, &hints, &res);
|
||||
if (status != 0) {
|
||||
perror("getaddrinfo error");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Iterate through the results and try to connect
|
||||
for (p = res; p != NULL; p = p->ai_next) {
|
||||
socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
|
||||
if (socket_fd < 0) {
|
||||
perror("unable to create socket");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (connect(socket_fd, p->ai_addr, p->ai_addrlen)) {
|
||||
close(socket_fd);
|
||||
continue;
|
||||
}
|
||||
|
||||
memcpy(&conn->local_addr, p->ai_addr, p->ai_addrlen);
|
||||
break;
|
||||
}
|
||||
|
||||
freeaddrinfo(res);
|
||||
|
||||
if (p == NULL) {
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
conn->socket_fd = socket_fd;
|
||||
|
||||
// Retrieve and store the remote address
|
||||
addr_len = sizeof(conn->remote_addr);
|
||||
if (getpeername(socket_fd, (struct sockaddr *)&conn->remote_addr, &addr_len) == -1) {
|
||||
perror("getpeername failed");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
// Create the SSL structure
|
||||
ssl = SSL_new(ctx);
|
||||
if (!ssl) {
|
||||
perror("Failed to create SSL object");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
if (!SSL_set_fd(ssl, socket_fd)) {
|
||||
perror("Failed to set SSL file descriptor");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
conn->ssl = ssl;
|
||||
|
||||
// Perform the SSL handshake
|
||||
if (SSL_connect(ssl) <= 0) {
|
||||
fprintf(stderr, "SSL handshake failed\n");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
return conn;
|
||||
|
||||
cleanup_ssl:
|
||||
if (ssl) SSL_free(ssl);
|
||||
cleanup_socket:
|
||||
if (socket_fd >= 0) close(socket_fd);
|
||||
cleanup_ctx:
|
||||
if (ctx) SSL_CTX_free(ctx);
|
||||
cleanup_conn:
|
||||
free(conn);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int set_socket_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
}
|
||||
|
||||
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
return vtpm.VerifyQuote(report, nil, vTpmNonce, a.writer, a.Policy)
|
||||
return vtpm.VerifyQuote(report, vTpmNonce, a.writer, a.Policy)
|
||||
}
|
||||
|
||||
func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
|
||||
@@ -6,8 +6,6 @@ package vtpm
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -105,14 +103,12 @@ func ExtendPCR(pcrIndex int, value []byte) error {
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
pubKey []byte
|
||||
teeAttestaion bool
|
||||
vmpl uint
|
||||
}
|
||||
|
||||
func NewProvider(pubKey []byte, teeAttestation bool, vmpl uint) attestation.Provider {
|
||||
func NewProvider(teeAttestation bool, vmpl uint) attestation.Provider {
|
||||
return &provider{
|
||||
pubKey: pubKey,
|
||||
teeAttestaion: teeAttestation,
|
||||
vmpl: vmpl,
|
||||
}
|
||||
@@ -140,19 +136,17 @@ func (v provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
type verifier struct {
|
||||
pubKey []byte
|
||||
writer io.Writer
|
||||
Policy *attestation.Config
|
||||
}
|
||||
|
||||
func NewVerifier(pubKey []byte, writer io.Writer) attestation.Verifier {
|
||||
func NewVerifier(writer io.Writer) attestation.Verifier {
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
return &verifier{
|
||||
pubKey: pubKey,
|
||||
writer: writer,
|
||||
Policy: policy,
|
||||
}
|
||||
@@ -160,11 +154,10 @@ func NewVerifier(pubKey []byte, writer io.Writer) attestation.Verifier {
|
||||
|
||||
func NewVerifierWithPolicy(pubKey []byte, writer io.Writer, policy *attestation.Config) attestation.Verifier {
|
||||
if policy == nil {
|
||||
return NewVerifier(pubKey, writer)
|
||||
return NewVerifier(writer)
|
||||
}
|
||||
|
||||
return &verifier{
|
||||
pubKey: pubKey,
|
||||
writer: writer,
|
||||
Policy: policy,
|
||||
}
|
||||
@@ -181,11 +174,11 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
}
|
||||
|
||||
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
return VerifyQuote(report, v.pubKey, vTpmNonce, v.writer, v.Policy)
|
||||
return VerifyQuote(report, vTpmNonce, v.writer, v.Policy)
|
||||
}
|
||||
|
||||
func (v verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
return VTPMVerify(report, v.pubKey, teeNonce, vTpmNonce, v.writer, v.Policy)
|
||||
return VTPMVerify(report, teeNonce, vTpmNonce, v.writer, v.Policy)
|
||||
}
|
||||
|
||||
func (v verifier) JSONToPolicy(path string) error {
|
||||
@@ -208,8 +201,8 @@ func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool, vmpl uint) ([
|
||||
return marshalQuote(attestation)
|
||||
}
|
||||
|
||||
func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
|
||||
if err := VerifyQuote(quote, pubKeyTLS, vtpmNonce, writer, policy); err != nil {
|
||||
func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
|
||||
if err := VerifyQuote(quote, vtpmNonce, writer, policy); err != nil {
|
||||
return fmt.Errorf("failed to verify vTPM quote: %v", err)
|
||||
}
|
||||
|
||||
@@ -227,7 +220,7 @@ func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byt
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifyQuote(quote []byte, pubKeyTLS []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
|
||||
func VerifyQuote(quote []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
|
||||
attestation := &attest.Attestation{}
|
||||
|
||||
err := proto.Unmarshal(quote, attestation)
|
||||
@@ -251,9 +244,7 @@ func VerifyQuote(quote []byte, pubKeyTLS []byte, vtpmNonce []byte, writer io.Wri
|
||||
return errors.Wrap(fmt.Errorf("failed to verify attestation"), err)
|
||||
}
|
||||
|
||||
s256, s384 := calculatePCRTLSKey(pubKeyTLS)
|
||||
|
||||
if err := checkExpectedPCRValues(attestation, s256, s384, policy); err != nil {
|
||||
if err := checkExpectedPCRValues(attestation, policy); err != nil {
|
||||
return fmt.Errorf("PCR values do not match expected PCR values: %w", err)
|
||||
}
|
||||
|
||||
@@ -330,39 +321,23 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkExpectedPCRValues(attQuote *attest.Attestation, ePcr256, ePcr384 []byte, policy *attestation.Config) error {
|
||||
func checkExpectedPCRValues(attQuote *attest.Attestation, policy *attestation.Config) error {
|
||||
quotes := attQuote.GetQuotes()
|
||||
for i := range quotes {
|
||||
quote := quotes[i]
|
||||
var pcrMap map[string]string
|
||||
var pcr15 []byte
|
||||
|
||||
switch quote.Pcrs.Hash {
|
||||
case ptpm.HashAlgo_SHA256:
|
||||
pcrMap = policy.PcrConfig.PCRValues.Sha256
|
||||
if ePcr256 == nil {
|
||||
pcr15 = make([]byte, 32)
|
||||
} else {
|
||||
pcr15 = ePcr256
|
||||
}
|
||||
case ptpm.HashAlgo_SHA384:
|
||||
pcrMap = policy.PcrConfig.PCRValues.Sha384
|
||||
if ePcr384 == nil {
|
||||
pcr15 = make([]byte, 48)
|
||||
} else {
|
||||
pcr15 = ePcr384
|
||||
}
|
||||
case ptpm.HashAlgo_SHA1:
|
||||
pcrMap = policy.PcrConfig.PCRValues.Sha1
|
||||
pcr15 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
default:
|
||||
return errors.Wrap(ErrNoHashAlgo, fmt.Errorf("algo: %s", ptpm.HashAlgo_name[int32(quote.Pcrs.Hash)]))
|
||||
}
|
||||
|
||||
pcr15Index := uint32(15)
|
||||
if !bytes.Equal(quote.Pcrs.Pcrs[pcr15Index], pcr15) {
|
||||
return fmt.Errorf("for algo %s PCR[15] expected %s but found %s", ptpm.HashAlgo_name[int32(quote.Pcrs.Hash)], hex.EncodeToString(pcr15), hex.EncodeToString(quote.Pcrs.Pcrs[pcr15Index]))
|
||||
}
|
||||
|
||||
for i, v := range pcrMap {
|
||||
index, err := strconv.ParseInt(i, 10, 32)
|
||||
if err != nil {
|
||||
@@ -380,27 +355,6 @@ func checkExpectedPCRValues(attQuote *attest.Attestation, ePcr256, ePcr384 []byt
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return SHA256 and SHA384 values of the input public key.
|
||||
func calculatePCRTLSKey(pubKey []byte) ([]byte, []byte) {
|
||||
if len(pubKey) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
init256 := make([]byte, Hash256)
|
||||
init384 := make([]byte, Hash384)
|
||||
|
||||
key256 := sha3.Sum256(pubKey)
|
||||
key384 := sha3.Sum384(pubKey)
|
||||
|
||||
pcrValue256 := append(init256, key256[:]...)
|
||||
pcrValue384 := append(init384, key384[:]...)
|
||||
|
||||
newPcr256 := sha256.Sum256(pcrValue256)
|
||||
newPcr384 := sha512.Sum384(pcrValue384)
|
||||
|
||||
return newPcr256[:], newPcr384[:]
|
||||
}
|
||||
|
||||
func getPCRValue(index int, algorithm tpm2.Algorithm) ([]byte, error) {
|
||||
rwc, err := OpenTpm()
|
||||
if err != nil {
|
||||
|
||||
+34
-40
@@ -1,19 +1,16 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build cgo
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
@@ -36,12 +33,9 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
|
||||
|
||||
attestation.AttestationPolicyPath = cfg.AttestationPolicy
|
||||
|
||||
var insecureSkipVerify bool = true
|
||||
var rootCAs *x509.CertPool = nil
|
||||
|
||||
if len(cfg.ServerCAFile) > 0 {
|
||||
insecureSkipVerify = false
|
||||
|
||||
// Read the certificate file
|
||||
certPEM, err := os.ReadFile(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
@@ -66,11 +60,20 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
|
||||
security = withmaTLS
|
||||
}
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to generate nonce"), err)
|
||||
}
|
||||
|
||||
encoded := hex.EncodeToString(nonce)
|
||||
sni := fmt.Sprintf("%s.nonce", encoded)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
InsecureSkipVerify: true,
|
||||
RootCAs: rootCAs,
|
||||
ServerName: sni,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateATLS(rawCerts, verifiedChains, cfg)
|
||||
return verifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -85,49 +88,40 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
|
||||
return credentials.NewTLS(tlsConfig), security, nil
|
||||
}
|
||||
|
||||
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create a custom dialer")
|
||||
}
|
||||
|
||||
p, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad format of IP address: %v", err)
|
||||
}
|
||||
|
||||
conn, err := atls.DialTLSClient(ip, p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create TLS connection")
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, cfg AgentClientConfig) error {
|
||||
if len(cfg.ServerCAFile) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return errors.Wrap(errCertificateParse, err)
|
||||
}
|
||||
|
||||
err = checkIfCertificateSelfSigned(cert)
|
||||
err = checkIfCertificateSigned(cert, rootCAs)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, ext := range cert.Extensions {
|
||||
pType, err := atls.GetPlatformTypeFromOID(ext.Id)
|
||||
if err == nil {
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
return atls.VerifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("attestation extension not found in certificate")
|
||||
}
|
||||
|
||||
func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
|
||||
certPool := x509.NewCertPool()
|
||||
certPool.AddCert(cert)
|
||||
func checkIfCertificateSigned(cert *x509.Certificate, rootCAs *x509.CertPool) error {
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: certPool,
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !cgo
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
|
||||
return nil, fmt.Errorf("aTLS is not supported without CGO. Please rebuild with CGO_ENABLED=1")
|
||||
}
|
||||
@@ -377,7 +377,7 @@ func TestCheckIfCertificateSelfSigned(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkIfCertificateSelfSigned(tt.cert)
|
||||
err := checkIfCertificateSigned(tt.cert, nil)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,7 +148,6 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(tc))
|
||||
opts = append(opts, grpc.WithContextDialer(CustomDialer))
|
||||
|
||||
secure = sec
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user