COCOS-474 - New aTLS implementation (#475)

* initial new aTLS

* add CA API call for aTLS
This commit is contained in:
Danko Miladinovic
2025-07-08 14:54:57 +02:00
committed by GitHub
parent 9c8ddfd2b1
commit 698bd948ed
19 changed files with 681 additions and 1941 deletions
+333
View File
@@ -0,0 +1,333 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"strings"
"time"
"github.com/absmach/certs"
certscli "github.com/absmach/certs/cli"
"github.com/absmach/certs/errors"
certssdk "github.com/absmach/certs/sdk"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
)
const (
vmpl2 = 2
organization = "Ultraviolet"
country = "Serbia"
province = ""
locality = "Belgrade"
streetAddress = "Bulevar Arsenija Carnojevica 103"
postalCode = "11000"
notAfterYear = 1
notAfterMonth = 0
notAfterDay = 0
)
var (
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
)
type csrReq struct {
CSR string `json:"csr,omitempty"`
}
func getPlatformProvider(platformType attestation.PlatformType) (attestation.Provider, error) {
switch platformType {
case attestation.SNPvTPM:
return vtpm.NewProvider(true, vmpl2), nil
case attestation.Azure:
return azure.NewProvider(), nil
case attestation.TDX:
return tdx.NewProvider(), nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
func getPlatformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
var verifier attestation.Verifier
switch platformType {
case attestation.SNPvTPM:
verifier = vtpm.NewVerifier(nil)
case attestation.Azure:
verifier = azure.NewVerifier(nil)
case attestation.TDX:
verifier = tdx.NewVerifier()
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
if err != nil {
return nil, err
}
return verifier, nil
}
func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
switch platformType {
case attestation.SNPvTPM:
return SNPvTPMOID, nil
case attestation.Azure:
return AzureOID, nil
case attestation.TDX:
return TDXOID, nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
switch {
case oid.Equal(SNPvTPMOID):
return attestation.SNPvTPM, nil
case oid.Equal(AzureOID):
return attestation.Azure, nil
case oid.Equal(TDXOID):
return attestation.TDX, nil
default:
return 0, fmt.Errorf("unsupported OID: %v", oid)
}
}
func VerifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
verifier, err := getPlatformVerifier(pType)
if err != nil {
return fmt.Errorf("failed to get platform verifier: %w", err)
}
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:vtpm.Nonce]); err != nil {
fmt.Printf("failed to verify attestation: %v\n", err)
return err
}
return nil
}
func GetCertificate(caUrl string, cvmId string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
pType := attestation.CCPlatform()
provider, err := getPlatformProvider(pType)
if err != nil {
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("failed to get platform provider: %w", err)
}
}
teeOid, err := getOID(pType)
if err != nil {
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("failed to get OID for platform type: %w", err)
}
}
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
curve := elliptic.P256()
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private/public key: %w", err)
}
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
sniLength := len(clientHello.ServerName)
if sniLength < 7 || clientHello.ServerName[sniLength-6:] != ".nonce" {
return nil, fmt.Errorf("invalid server name: %s", clientHello.ServerName)
}
nonceStr := clientHello.ServerName[:sniLength-6]
nonce, err := hex.DecodeString(nonceStr)
if err != nil {
return nil, fmt.Errorf("failed to decode nonce from server name: %w", err)
}
if len(nonce) != 64 {
return nil, fmt.Errorf("invalid nonce length: expected 64 bytes, got %d bytes", len(nonce))
}
attestExtension, err := getCertificateExtension(provider, pubKeyDER, nonce, teeOid)
if err != nil {
return nil, fmt.Errorf("failed to get certificate extension: %w", err)
}
var certDERBytes []byte
if caUrl == "" && cvmId == "" {
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(202403311),
Subject: pkix.Name{
Organization: []string{organization},
Country: []string{country},
Province: []string{province},
Locality: []string{locality},
StreetAddress: []string{streetAddress},
PostalCode: []string{postalCode},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: []pkix.Extension{attestExtension},
}
DERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err)
}
certDERBytes = DERBytes
} else {
csrmd := certs.CSRMetadata{
Organization: []string{organization},
Country: []string{country},
Province: []string{province},
Locality: []string{locality},
StreetAddress: []string{streetAddress},
PostalCode: []string{postalCode},
ExtraExtensions: []pkix.Extension{attestExtension},
}
csr, err := certscli.CreateCSR(csrmd, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to create CSR: %w", err)
}
csrData := string(csr.CSR)
r := csrReq{
CSR: csrData,
}
data, sdkErr := json.Marshal(r)
if sdkErr != nil {
return nil, fmt.Errorf("failed to marshal CSR request: %w", sdkErr)
}
notBefore := time.Now()
notAfter := time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay)
ttlString := notAfter.Sub(notBefore).String()
query := url.Values{}
query.Add("ttl", ttlString)
query_string := query.Encode()
certsEndpoint := "certs"
csrEndpoint := "csrs"
endpoint := fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, cvmId)
url := fmt.Sprintf("%s/%s?%s", caUrl, endpoint, query_string)
_, body, err := processRequest(http.MethodPost, url, data, nil, http.StatusOK)
if err != nil {
return nil, fmt.Errorf("failed to process request: %w", err)
}
var cert certssdk.Certificate
if err := json.Unmarshal(body, &cert); err != nil {
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
}
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to convert generated certificate to DER format: %s", cleanCertificateString)
}
certDERBytes = block.Bytes
}
return &tls.Certificate{
Certificate: [][]byte{certDERBytes},
PrivateKey: privateKey,
}, nil
}
}
func getCertificateExtension(provider attestation.Provider, pubKey []byte, nonce []byte, teeOid asn1.ObjectIdentifier) (pkix.Extension, error) {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
rawAttestation, err := provider.Attestation(hashNonce[:], hashNonce[:vtpm.Nonce])
if err != nil {
return pkix.Extension{}, fmt.Errorf("failed to get attestation: %w", err)
}
return pkix.Extension{
Id: teeOid,
Value: rawAttestation,
}, nil
}
func processRequest(method, reqUrl string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data))
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
// Sets a default value for the Content-Type.
// Overridden if Content-Type is passed in the headers arguments.
req.Header.Add("Content-Type", "application/json")
for key, value := range headers {
req.Header.Add(key, value)
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
resp, err := client.Do(req)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
defer resp.Body.Close()
sdkerr := errors.CheckError(resp, expectedRespCodes...)
if sdkerr != nil {
return make(http.Header), []byte{}, sdkerr
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
return resp.Header, body, nil
}
-453
View File
@@ -1,453 +0,0 @@
// Code generated by cmd/cgo; DO NOT EDIT.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
// #cgo LDFLAGS: -lssl -lcrypto
// #include "extensions.h"
// #include <string.h>
import "C"
import (
"crypto/sha3"
"fmt"
"io"
"net"
"os"
"strconv"
"sync"
"syscall"
"time"
"unsafe"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
const (
noError = 0
errorZeroReturn = 6
errorWantRead = 2
errorWantWrite = 3
errorSyscall = 5
errorSsl = 1
waitTime = 2
vmpl2 = 2
)
var (
errListener = errors.New("listener could not be created")
errBadIPFormat = errors.New("bad format of IP address")
errCloseTLS = errors.New("could not close TLS connection")
errConnFailed = errors.New("tls connection is nil")
errWrite = errors.New("could not write to TLS")
errTLSConn = errors.New("connection did not close correctly")
errReadDeadline = errors.New("could not set read deadline, socket timeout failed")
errWriteDeadline = errors.New("could not set write deadline, socket timeout failed")
errConnCreate = errors.New("could not create connection")
)
func formTeeData(pubKey []byte, teeNonce []byte) []byte {
combined := append(pubKey, teeNonce...)
sum := sha3.Sum512(combined)
return sum[:]
}
func getPlatformProvider(platformType attestation.PlatformType, pubKey []byte) (attestation.Provider, error) {
switch platformType {
case attestation.SNPvTPM:
return vtpm.NewProvider(pubKey, true, vmpl2), nil
case attestation.Azure:
return azure.NewProvider(), nil
case attestation.TDX:
return tdx.NewProvider(), nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
func getPlatformVerifier(platformType attestation.PlatformType, pubKey []byte) (attestation.Verifier, error) {
var verifier attestation.Verifier
switch platformType {
case attestation.SNPvTPM:
verifier = vtpm.NewVerifier(pubKey, nil)
case attestation.Azure:
verifier = azure.NewVerifier(nil)
case attestation.TDX:
verifier = tdx.NewVerifier()
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
if err != nil {
return nil, err
}
return verifier, nil
}
//export callVerificationValidationCallback
func callVerificationValidationCallback(platformType C.int, pubKey *C.uchar, pubKeyLen C.int, attestReport *C.uchar, attestReportSize C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar) C.int {
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce)
vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce)
pType := attestation.PlatformType(int(platformType))
attestationReport := C.GoBytes(unsafe.Pointer(attestReport), attestReportSize)
teeData := formTeeData(pubKeyCert, teeNonceData)
verifier, err := getPlatformVerifier(pType, pubKeyCert)
if err != nil {
fmt.Fprintf(os.Stderr, "no attestation provider found for platform type %s", err.Error())
return C.int(-1)
}
err = verifier.VerifyAttestation(attestationReport, teeData, vTPMNonce)
if err != nil {
fmt.Fprintf(os.Stderr, "verification callback failed %s", err.Error())
return C.int(-1)
}
return C.int(0)
}
//export callFetchAttestationCallback
func callFetchAttestationCallback(platformType C.int, pubKey *C.uchar, pubKeyLen C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar, outlen *C.ulong) *C.uchar {
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce)
vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce)
pType := attestation.PlatformType(int(platformType))
teeData := formTeeData(pubKeyCert, teeNonceData)
provider, err := getPlatformProvider(pType, pubKeyCert)
if err != nil {
fmt.Fprintf(os.Stderr, "no attestation provider found for platform type %s", err.Error())
return nil
}
quote, err := provider.Attestation(teeData, vTPMNonce)
if err != nil {
fmt.Fprintf(os.Stderr, "attestation callback returned nil: %s", err.Error())
return nil
}
*outlen = C.ulong(len(quote))
resultC := C.malloc(C.size_t(len(quote)))
if resultC == nil {
fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback")
return nil
}
C.memcpy(resultC, unsafe.Pointer(&quote[0]), C.size_t(len(quote)))
return (*C.uchar)(resultC)
}
//export returnCCPlatformType
func returnCCPlatformType() int32 {
return int32(attestation.CCPlatform())
}
type ATLSServerListener struct {
tlsListener *C.tls_server_connection
}
func Listen(addr string, cert []byte, key []byte) (net.Listener, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, errors.Wrap(errListener, err)
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, errors.Wrap(errBadIPFormat, err)
}
cCertPEM := (*C.char)(unsafe.Pointer(&cert[0]))
cKeyPEM := (*C.char)(unsafe.Pointer(&key[0]))
cIP := C.CString(ip)
defer C.free(unsafe.Pointer(cIP))
atlsListener := C.start_tls_server(
cCertPEM, C.int(len(cert)),
cKeyPEM, C.int(len(key)),
cIP, C.int(p))
if atlsListener == nil {
return nil, errors.Wrap(errListener, err)
}
return &ATLSServerListener{tlsListener: atlsListener}, nil
}
// accept implements the Accept method in the [Listener] interface; it
// waits for the next call and returns a generic [Conn].
func (l *ATLSServerListener) Accept() (net.Conn, error) {
conn := C.tls_server_accept(l.tlsListener)
if conn == nil {
return &ATLSConn{tlsConn: nil}, nil
}
return &ATLSConn{tlsConn: conn}, nil
}
// close stops listening on the TCP address.
// already Accepted connections are not closed.
func (l *ATLSServerListener) Close() error {
ret := C.tls_server_close(l.tlsListener)
if ret != 0 {
return errCloseTLS
}
return nil
}
// addr returns the listener's network address, a [*TCPAddr].
// the Addr returned is shared by all invocations of Addr, so
// do not modify it.
func (l *ATLSServerListener) Addr() net.Addr {
cIP := C.tls_return_addr(&l.tlsListener.addr)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoString(cIP)
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return nil
}
port := C.tls_return_port(&l.tlsListener.addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
type ATLSConn struct {
tlsConn *C.tls_connection
fdReadMutex sync.Mutex
fdWriteMutex sync.Mutex
fdDelayMutex sync.Mutex
}
func (c *ATLSConn) Read(b []byte) (int, error) {
c.fdReadMutex.Lock()
defer c.fdReadMutex.Unlock()
if c.tlsConn == nil {
return 0, errConnFailed
}
n := int(C.tls_read(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
if n > 0 {
return n, nil
}
// call the C function SSL_get_error to interpret the error.
errCode := int(C.SSL_get_error(c.tlsConn.ssl, C.int(n)))
// handle specific error codes returned by SSL_get_error.
switch errCode {
case noError:
return n, nil // no error.
case errorZeroReturn:
fmt.Fprintf(os.Stdout, "Connection closed by peer\n")
return 0, io.EOF // connection closed.
case errorWantRead:
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later\n")
return 0, nil // non-fatal, just retry later.
case errorWantWrite:
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later\n")
return 0, nil // non-fatal, just retry later.
case errorSyscall:
fmt.Fprintf(os.Stderr, "I/O error\n")
return 0, syscall.ECONNRESET // return connection reset error.
case errorSsl:
fmt.Fprintf(os.Stderr, "I/O error\n")
return 0, syscall.ECONNRESET // return connection reset error.
default:
fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode)
return 0, fmt.Errorf("SSL error")
}
}
func (c *ATLSConn) Write(b []byte) (int, error) {
c.fdWriteMutex.Lock()
defer c.fdWriteMutex.Unlock()
if c.tlsConn == nil {
return 0, errConnFailed
}
n := int(C.tls_write(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
if n < 0 {
return 0, errWrite
}
return n, nil
}
func (c *ATLSConn) Close() error {
c.fdReadMutex.Lock()
defer c.fdReadMutex.Unlock()
c.fdWriteMutex.Lock()
defer c.fdWriteMutex.Unlock()
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
for {
ret := C.tls_close(c.tlsConn)
if int(ret) == 0 {
c.fdDelayMutex.Unlock()
c.fdWriteMutex.Unlock()
c.fdReadMutex.Unlock()
time.Sleep(waitTime * time.Millisecond)
c.fdDelayMutex.Lock()
c.fdWriteMutex.Lock()
c.fdReadMutex.Lock()
} else if int(ret) < 0 {
c.tlsConn = nil
return errTLSConn
} else if int(ret) == 1 {
c.tlsConn = nil
break
}
}
return nil
}
func (c *ATLSConn) LocalAddr() net.Addr {
if c.tlsConn == nil {
return nil
}
cIP := C.tls_return_addr(&c.tlsConn.local_addr)
ipLength := C.strlen(cIP)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoStringN(cIP, C.int(ipLength))
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
fmt.Println("Invalid IP address")
return nil
}
port := C.tls_return_port(&c.tlsConn.local_addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
func (c *ATLSConn) RemoteAddr() net.Addr {
if c.tlsConn == nil {
return nil
}
cIP := C.tls_return_addr(&c.tlsConn.remote_addr)
if cIP == nil {
fmt.Println("RemoteAddr error while fetching ip")
return nil
}
ipLength := C.strlen(cIP)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoStringN(cIP, C.int(ipLength))
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
fmt.Println("Invalid IP address")
return nil
}
port := C.tls_return_port(&c.tlsConn.remote_addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
func (c *ATLSConn) SetDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errReadDeadline
}
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errWriteDeadline
}
return nil
}
func (c *ATLSConn) SetReadDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errReadDeadline
}
return nil
}
func (c *ATLSConn) SetWriteDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errWriteDeadline
}
return nil
}
func DialTLSClient(hostname string, port int) (net.Conn, error) {
cHostName := C.CString(hostname)
defer C.free(unsafe.Pointer(cHostName))
conn := C.new_tls_connection(cHostName, C.int(port))
if conn == nil {
return nil, errConnCreate
}
return &ATLSConn{tlsConn: conn}, nil
}
func timeToTimeout(t time.Time) (int, int) {
if t.IsZero() {
return 0, 0
}
d := time.Until(t)
if d <= 0 {
return 0, 0
}
seconds := int(d.Seconds())
microseconds := int(d.Nanoseconds()/1000) % 1_000_000
return seconds, microseconds
}
-101
View File
@@ -1,101 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestListen(t *testing.T) {
cert := []byte("dummy_cert")
key := []byte("dummy_key")
cases := []struct {
name string
address string
err error
}{
{
name: "Valid address",
address: "127.0.0.1:8889",
err: nil,
},
{
name: "Invalid address format",
address: "127.0.0.1",
err: errListener,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
l, err := Listen(c.address, cert, key)
assert.True(t, errors.Contains(err, c.err))
if l != nil {
t.Cleanup(func() {
err := l.Close()
assert.NoError(t, err)
})
}
})
}
}
func TestATLSServerListener_Accept(t *testing.T) {
t.Run("Accepts connection", func(t *testing.T) {
listener, err := Listen("127.0.0.1:8887", []byte("dummy_cert"), []byte("dummy_key"))
assert.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
assert.NoError(t, err)
})
conn, err := listener.Accept()
assert.NoError(t, err)
assert.NotNil(t, conn)
})
}
func TestATLSConn_Read(t *testing.T) {
buffer := make([]byte, 1024)
t.Run("Read with nil connection", func(t *testing.T) {
conn := &ATLSConn{tlsConn: nil}
_, err := conn.Read(buffer)
assert.Error(t, err)
assert.Equal(t, err, errConnFailed)
})
}
func TestATLSConn_Write(t *testing.T) {
data := []byte("test data")
t.Run("Write with nil connection", func(t *testing.T) {
conn := &ATLSConn{tlsConn: nil}
_, err := conn.Write(data)
assert.Error(t, err)
assert.Equal(t, err, errConnFailed)
})
}
func TestATLSConn_DeadlineFunctions(t *testing.T) {
conn := &ATLSConn{}
t.Run("SetDeadline - valid time", func(t *testing.T) {
err := conn.SetDeadline(time.Now().Add(1 * time.Minute))
assert.NoError(t, err)
})
t.Run("SetReadDeadline - past time", func(t *testing.T) {
err := conn.SetReadDeadline(time.Now().Add(-1 * time.Minute))
assert.NoError(t, err)
})
t.Run("SetWriteDeadline - zero time", func(t *testing.T) {
err := conn.SetWriteDeadline(time.Time{})
assert.NoError(t, err)
})
}
+270
View File
@@ -0,0 +1,270 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"encoding/asn1"
"os"
"path/filepath"
"testing"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/protobuf/encoding/protojson"
)
const sevProductNameMilan = "Milan"
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
func TestGetPlatformProvider(t *testing.T) {
cases := []struct {
name string
platformType attestation.PlatformType
expectedError error
}{
{
name: "Valid platform type SNPvTPM",
platformType: attestation.SNPvTPM,
expectedError: nil,
},
{
name: "Valid platform type Azure",
platformType: attestation.Azure,
expectedError: nil,
},
{
name: "Valid platform type TDX",
platformType: attestation.TDX,
expectedError: nil,
},
{
name: "Invalid platform type",
platformType: 999,
expectedError: errors.New("unsupported platform type: 999"),
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
provider, err := getPlatformProvider(c.platformType)
if c.expectedError != nil {
assert.Error(t, err)
assert.Equal(t, c.expectedError.Error(), err.Error())
assert.Nil(t, provider)
} else {
assert.NoError(t, err)
assert.NotNil(t, provider)
}
})
}
}
func TestGetPlatformVerifier(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
cases := []struct {
name string
platformType attestation.PlatformType
expectedError error
}{
{
name: "Valid platform type SNPvTPM",
platformType: attestation.SNPvTPM,
expectedError: nil,
},
{
name: "Valid platform type Azure",
platformType: attestation.Azure,
expectedError: nil,
},
{
name: "Valid platform type TDX",
platformType: attestation.TDX,
expectedError: errors.New("unknown field \"pcr_values\""),
},
{
name: "Invalid platform type",
platformType: 999,
expectedError: errors.New("unsupported platform type: 999"),
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
verifier, err := getPlatformVerifier(c.platformType)
if c.expectedError != nil {
assert.Error(t, err)
assert.Contains(t, err.Error(), c.expectedError.Error())
assert.Nil(t, verifier)
} else {
assert.NoError(t, err)
assert.NotNil(t, verifier)
}
})
}
}
func TestGetOID(t *testing.T) {
cases := []struct {
name string
platformType attestation.PlatformType
expectedOID asn1.ObjectIdentifier
expectedError error
}{
{
name: "Valid platform type SNPvTPM",
platformType: attestation.SNPvTPM,
expectedOID: SNPvTPMOID,
expectedError: nil,
},
{
name: "Valid platform type Azure",
platformType: attestation.Azure,
expectedOID: AzureOID,
expectedError: nil,
},
{
name: "Valid platform type TDX",
platformType: attestation.TDX,
expectedOID: TDXOID,
expectedError: nil,
},
{
name: "Invalid platform type",
platformType: 999,
expectedOID: nil,
expectedError: errors.New("unsupported platform type: 999"),
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
oid, err := getOID(c.platformType)
if c.expectedError != nil {
assert.Error(t, err)
assert.Equal(t, c.expectedError.Error(), err.Error())
assert.Nil(t, oid)
} else {
assert.NoError(t, err)
assert.Equal(t, c.expectedOID, oid)
}
})
}
}
func TestGetPlatformTypeFromOID(t *testing.T) {
cases := []struct {
name string
oid asn1.ObjectIdentifier
expectedType attestation.PlatformType
expectedError error
}{
{
name: "Valid OID for SNPvTPM",
oid: SNPvTPMOID,
expectedType: attestation.SNPvTPM,
expectedError: nil,
},
{
name: "Valid OID for Azure",
oid: AzureOID,
expectedType: attestation.Azure,
expectedError: nil,
},
{
name: "Valid OID for TDX",
oid: TDXOID,
expectedType: attestation.TDX,
expectedError: nil,
},
{
name: "Invalid OID",
oid: asn1.ObjectIdentifier{1, 2, 3},
expectedType: 0,
expectedError: errors.New("unsupported OID: 1.2.3"),
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
pType, err := GetPlatformTypeFromOID(c.oid)
if c.expectedError != nil {
assert.Error(t, err)
assert.Equal(t, c.expectedError.Error(), err.Error())
assert.Equal(t, attestation.PlatformType(0), pType)
} else {
assert.NoError(t, err)
assert.Equal(t, c.expectedType, pType)
}
})
}
}
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
file, err := os.ReadFile("../../attestation.bin")
require.NoError(t, err)
if len(file) < abi.ReportSize {
file = append(file, make([]byte, abi.ReportSize-len(file))...)
}
rr, err := abi.ReportCertsToProto(file)
require.NoError(t, err)
return rr
}
func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error {
attestationPolicyFile, err := os.ReadFile("../../scripts/attestation_policy/attestation_policy.json")
if err != nil {
return err
}
unmarshalOptions := protojson.UnmarshalOptions{DiscardUnknown: true}
err = unmarshalOptions.Unmarshal(attestationPolicyFile, policy)
if err != nil {
return err
}
policy.Config.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
policy.Config.Policy.FamilyId = rr.Report.FamilyId
policy.Config.Policy.ImageId = rr.Report.ImageId
policy.Config.Policy.Measurement = rr.Report.Measurement
policy.Config.Policy.HostData = rr.Report.HostData
policy.Config.Policy.ReportIdMa = rr.Report.ReportIdMa
policy.Config.RootOfTrust.ProductLine = sevProductNameMilan
policyByte, err := vtpm.ConvertPolicyToJSON(&policy)
if err != nil {
return err
}
policyPath := filepath.Join(policyDirectory, "attestation_policy.json")
err = os.WriteFile(policyPath, policyByte, 0o644)
if err != nil {
return nil
}
attestation.AttestationPolicyPath = policyPath
return nil
}
-322
View File
@@ -1,322 +0,0 @@
#include "extensions.h"
#include <stdio.h>
#include <string.h>
#include <openssl/sha.h>
#include <openssl/rand.h>
#include <openssl/x509.h>
#include <fcntl.h>
#include <unistd.h>
extern int callVerificationValidationCallback(int platformType, const u_char* pubKey, int pubKeyLen, const u_char* quote, int quoteSize, const u_char* teeNonceByte, const u_char* vTPMNonceByte);
extern u_char* callFetchAttestationCallback(int platformType, const u_char* pubKey, int pubKeyLen, const u_char* teeNonceByte, const u_char* vTPMNonceByte, unsigned long* outlen);
extern uintptr_t validationVerificationCallback(int teeType);
extern uintptr_t getPlatformTypeHandle(int platformType, u_char *teeNonce, u_char *vtpmNonce);
extern int returnCCPlatformType();
int triggerVerificationValidationCallback(int platformType, u_char* pub_key, int pub_key_len, u_char *quote, int quote_size, u_char *tee_nonce, u_char *vtpm_nonce) {
if (quote == NULL || vtpm_nonce == NULL || tee_nonce == NULL || pub_key == NULL) {
fprintf(stderr, "attestation and noce and public key cannot be NULL\n");
return -1;
}
return callVerificationValidationCallback(platformType, pub_key, pub_key_len, quote, quote_size, tee_nonce, vtpm_nonce);
}
u_char* triggerFetchAttestationCallback(int platformType, u_char* pub_key, int pub_key_len, char *tee_nonce, char *vtpm_nonce, unsigned long *outlen) {
if(tee_nonce == NULL || vtpm_nonce == NULL) {
fprintf(stderr, "Report data cannot be NULL");
return NULL;
}
return callFetchAttestationCallback(platformType, pub_key, pub_key_len, tee_nonce, vtpm_nonce, outlen);
}
/*
Evidence request extension
- Contains a random nonce that goes into the attestation report
- Is sent in the ClientHello message
*/
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *out,
void *add_arg)
{
free((void *)out);
}
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
evidence_request *er = (evidence_request*)malloc(sizeof(evidence_request));
if (er == NULL) {
perror("could not allocate memory");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
if (ext_data != NULL) {
if (RAND_bytes(ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE) != 1) {
perror("could not generate random bytes for vtpm nonce, will use SSL client random");
SSL_get_client_random(s, ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE);
}
if (RAND_bytes(ext_data->er.tee_nonce, REPORT_DATA_SIZE) != 1) {
perror("could not generate random bytes for tee nonce, will use SSL client random");
SSL_get_client_random(s, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
}
} else {
fprintf(stderr, "add_arg is NULL\n");
free(er);
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
memcpy(er->vtpm_nonce, ext_data->er.vtpm_nonce, NONCE_RANDOM_SIZE);
memcpy(er->tee_nonce, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
*out = (const u_char *)er;
*outlen = sizeof(evidence_request);
return 1;
}
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
if (ext_data != NULL) {
int32_t *platform_type = (int32_t*)malloc(sizeof(int32_t));
if (platform_type == NULL) {
perror("could not allocate memory");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
*platform_type = returnCCPlatformType();
ext_data->platform_type = *platform_type;
*out = (u_char*)platform_type;
*outlen = sizeof(int32_t);
} else {
fprintf(stderr, "add_arg is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
return 1;
}
default:
break;
}
fprintf(stderr, "bad context\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
{
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
evidence_request *er = (evidence_request*)in;
if (ext_data != NULL) {
memcpy(ext_data->er.vtpm_nonce, er->vtpm_nonce, NONCE_RANDOM_SIZE);
memcpy(ext_data->er.tee_nonce, er->tee_nonce, REPORT_DATA_SIZE);
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
}
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
{
int *platform_type = (int*)in;
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
if (ext_data != NULL) {
ext_data->platform_type = *platform_type;
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
}
default:
fprintf(stderr, "bad context\n");
return 0;
}
}
/*
Attestation Certificate extension
- Contains the attestation report
- The attestation report contains the hash of the nonce, the Public Key of the x.509 Agent certificate, and the vTPM AK
*/
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *out,
void *add_arg)
{
free((void *)out);
}
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
return 1;
case SSL_EXT_TLS1_3_CERTIFICATE:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
if (ext_data != NULL) {
u_char *quote;
size_t len = 0;
EVP_PKEY *pkey = NULL;
u_char *pubkey_buf = NULL;
int pubkey_len = 0;
if (x != NULL) {
pkey = X509_get_pubkey(x);
if (pkey == NULL) {
fprintf(stderr, "Failed to extract public key from certificate\n");
return -1;
}
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
if (pubkey_len <= 0) {
fprintf(stderr, "Failed to convert public key to DER format\n");
EVP_PKEY_free(pkey);
return -1;
}
} else {
fprintf(stderr, "agent certificate must be used for aTLS\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
quote = triggerFetchAttestationCallback(ext_data->platform_type, pubkey_buf, pubkey_len, ext_data->er.tee_nonce, ext_data->er.vtpm_nonce, &len);
if (quote == NULL) {
fprintf(stderr, "attestation report is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
EVP_PKEY_free(pkey);
OPENSSL_free(pubkey_buf);
return -1;
}
EVP_PKEY_free(pkey);
OPENSSL_free(pubkey_buf);
*out = quote;
*outlen = len;
return 1;
} else {
fprintf(stderr, "add_arg is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
}
default:
fprintf(stderr, "bad context\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
}
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
return 1;
case SSL_EXT_TLS1_3_CERTIFICATE:
{
if (x != NULL) {
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
if (ext_data != NULL) {
char *quote = (char*)malloc(inlen*sizeof(char));
EVP_PKEY *pkey = NULL;
u_char *pubkey_buf = NULL;
int pubkey_len = 0;
int res = 0;
if (quote == NULL) {
perror("could not allocate memory");
return 0;
}
pkey = X509_get_pubkey(x);
if (pkey == NULL) {
fprintf(stderr, "Failed to extract public key from certificate\n");
return -1;
}
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
if (pubkey_len <= 0) {
fprintf(stderr, "Failed to convert public key to DER format\n");
EVP_PKEY_free(pkey);
return -1;
}
memcpy(quote, in, inlen);
res = triggerVerificationValidationCallback(ext_data->platform_type,
pubkey_buf,
pubkey_len,
quote,
inlen,
(u_char*)&ext_data->er.tee_nonce,
(u_char*)&ext_data->er.vtpm_nonce);
free(quote);
EVP_PKEY_free(pkey);
OPENSSL_free(pubkey_buf);
if (res != 0) {
fprintf(stderr, "verification and validation failed, aborting connection\n");
return 0;
}
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
} else {
fprintf(stderr, "agent certificates must be used for aTLS\n");
return 0;
}
}
default:
fprintf(stderr, "bad context\n");
return 0;
}
}
-92
View File
@@ -1,92 +0,0 @@
#ifndef ATLS_EXTENSION_H
#define ATLS_EXTENSION_H
#include <openssl/ssl.h>
#include <arpa/inet.h>
#define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65
#define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66
#define REPORT_DATA_SIZE 64
#define NONCE_RANDOM_SIZE 32
#define TLS_CLIENT_CTX 0
#define TLS_SERVER_CTX 1
typedef struct evidence_request
{
char vtpm_nonce[NONCE_RANDOM_SIZE];
char tee_nonce[REPORT_DATA_SIZE];
} evidence_request;
typedef struct tls_extension_data
{
int platform_type;
evidence_request er;
} tls_extension_data;
typedef struct tls_server_connection
{
int server_fd;
char* cert;
int cert_len;
char* key;
int key_len;
struct sockaddr_storage addr;
} tls_server_connection;
typedef struct tls_connection
{
SSL_CTX *ctx;
SSL *ssl;
int socket_fd;
struct sockaddr_storage local_addr;
struct sockaddr_storage remote_addr;
tls_extension_data tls_ext_data;
} tls_connection;
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port);
tls_connection* tls_server_accept(tls_server_connection *tls_server);
int tls_server_close(tls_server_connection *tls_server);
int tls_read(tls_connection *conn, void *buf, int num);
int tls_write(tls_connection *conn, const void *buf, int num);
int tls_close(tls_connection *conn);
tls_connection* new_tls_connection(char *address, int port);
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
char* tls_return_addr(struct sockaddr_storage *addr);
int tls_return_port(struct sockaddr_storage *addr);
// Extensions
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *out,
void *add_arg);
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg);
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg);
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *out,
void *add_arg);
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg);
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg);
#endif // ATLS_EXTENSION_H
-607
View File
@@ -1,607 +0,0 @@
#include "extensions.h"
#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <stdio.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <ifaddrs.h>
void init_openssl() {
SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
}
void cleanup_openssl() {
EVP_cleanup();
}
int load_certificates_from_memory(SSL_CTX* ctx, const char* cert, int cert_len, const char* key, int key_len) {
BIO* cert_bio = NULL;
BIO* key_bio = NULL;
X509* x509_cert = NULL;
EVP_PKEY* pkey = NULL;
int success = 0;
cert_bio = BIO_new_mem_buf(cert, cert_len);
if (cert_bio == NULL) {
goto cleanup;
}
key_bio = BIO_new_mem_buf(key, key_len);
if (key_bio == NULL) {
goto cleanup;
}
x509_cert = PEM_read_bio_X509(cert_bio, NULL, 0, NULL);
if (x509_cert == NULL) {
goto cleanup;
}
pkey = PEM_read_bio_PrivateKey(key_bio, NULL, 0, NULL);
if (pkey == NULL) {
goto cleanup;
}
if (SSL_CTX_use_certificate(ctx, x509_cert) <= 0) {
goto cleanup;
}
if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) {
goto cleanup;
}
if (SSL_CTX_check_private_key(ctx) <= 0) {
goto cleanup;
}
success = 1;
cleanup:
if (cert_bio) BIO_free(cert_bio);
if (key_bio) BIO_free(key_bio);
if (x509_cert) X509_free(x509_cert);
if (pkey) EVP_PKEY_free(pkey);
if (!success) {
ERR_print_errors_fp(stderr);
}
return success;
}
int enforce_tls1_3_only(SSL_CTX *ctx) {
if (SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION) == 0) {
return 0;
}
if (SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) == 0) {
return 0;
}
return 1;
}
SSL_CTX *create_context(int is_server) {
const SSL_METHOD *method;
SSL_CTX *ctx;
if (is_server) {
method = TLS_server_method();
} else {
method = TLS_client_method();
}
ctx = SSL_CTX_new(method);
if (!ctx) {
perror("Unable to create SSL context");
ERR_print_errors_fp(stderr);
return NULL;
}
if (!enforce_tls1_3_only(ctx)) {
fprintf(stderr, "could not enforce TLS1.3\n");
SSL_CTX_free(ctx);
return NULL;
}
return ctx;
}
int add_custom_tls_extension(SSL_CTX *ctx, tls_connection *conn) {
uint32_t flags_nonce = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS;
uint32_t flags_attestation = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_CERTIFICATE;
int ret = 1;
void *data = NULL;
if (conn != NULL) {
data = (void*)&conn->tls_ext_data;
}
ret = SSL_CTX_add_custom_ext(ctx,
EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE,
flags_nonce,
evidence_request_ext_add_cb,
evidence_request_ext_free_cb,
data,
evidence_request_ext_parse_cb,
data);
if (ret != 1) {
return 0;
}
ret = SSL_CTX_add_custom_ext(ctx,
ATTESTATION_CERTIFICATE_EXTENSION_TYPE,
flags_attestation,
attestation_certificate_ext_add_cb,
attestation_certificate_ext_free_cb,
data,
attestation_certificate_ext_parse_cb,
data);
if (ret != 1) {
return 0;
}
return ret;
}
// Function to start the TLS server
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port) {
tls_server_connection *tls_server = (tls_server_connection*)malloc(sizeof(tls_server_connection));
int opt = 0;
if (tls_server == NULL) {
perror("memory could not be allocated");
return NULL;
}
init_openssl();
tls_server->cert = (char*)malloc(cert_len * sizeof(char));
if (tls_server->cert == NULL) {
perror("memory could not be allocated");
goto cleanup_tls_server;
}
tls_server->key = (char*)malloc(key_len * sizeof(char));
if (tls_server->key == NULL) {
perror("memory could not be allocated");
goto cleanup_cert;
}
memcpy(tls_server->cert, cert, cert_len);
memcpy(tls_server->key, key, key_len);
tls_server->cert_len = cert_len;
tls_server->key_len = key_len;
tls_server->server_fd = socket(AF_INET6, SOCK_STREAM, 0);
if (tls_server->server_fd < 0) {
perror("Unable to create socket");
goto cleanup_key;
}
// Enable both IPv4-mapped and IPv6 addresses
if (setsockopt(tls_server->server_fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) != 0) {
perror("setsockopt(IPV6_V6ONLY) failed");
goto cleanup_socket;
}
// Configure address structure
memset(&(tls_server->addr), 0, sizeof(tls_server->addr));
struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&tls_server->addr;
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
// Set the appropriate address (IPv4-mapped if needed)
if (ip == NULL || (strlen(ip) + 1) < INET_ADDRSTRLEN) {
addr6->sin6_addr = in6addr_any;
} else if (strchr(ip, ':') != NULL) {
if (inet_pton(AF_INET6, ip, &(addr6->sin6_addr)) <= 0) {
perror("Invalid IPv6 address");
goto cleanup_socket;
}
} else {
struct in_addr ipv4_addr;
if (inet_pton(AF_INET, ip, &ipv4_addr) <= 0) {
perror("Invalid IPv4 address");
goto cleanup_socket;
}
memset(&addr6->sin6_addr, 0, sizeof(addr6->sin6_addr));
addr6->sin6_addr.s6_addr[10] = 0xff;
addr6->sin6_addr.s6_addr[11] = 0xff;
memcpy(&addr6->sin6_addr.s6_addr[12], &ipv4_addr, sizeof(ipv4_addr));
}
if (bind(tls_server->server_fd, (struct sockaddr*)&(tls_server->addr), sizeof(tls_server->addr)) < 0) {
perror("Unable to bind");
goto cleanup_socket;
}
if (listen(tls_server->server_fd, SOMAXCONN) < 0) {
perror("Unable to listen");
goto cleanup_socket;
}
printf("Listening on port: %d\n", port);
return tls_server;
// Cleanup labels
cleanup_socket:
close(tls_server->server_fd);
cleanup_key:
free(tls_server->key);
cleanup_cert:
free(tls_server->cert);
cleanup_tls_server:
free(tls_server);
return NULL;
}
// Function to accept a client connection
tls_connection* tls_server_accept(tls_server_connection *tls_server) {
uint32_t len = sizeof(struct sockaddr_storage);
tls_connection *conn = (tls_connection*)malloc(sizeof(tls_connection));
int client_fd = -1;
int ret = 0;
if (conn == NULL) {
perror("Unable to allocate memory for tls_connection");
return NULL;
}
conn->ctx = NULL;
conn->ssl = NULL;
conn->socket_fd = -1;
// Initialize the context
conn->ctx = create_context(TLS_SERVER_CTX);
if (conn->ctx == NULL) {
perror("Unable to create context");
goto cleanup_conn;
}
// Load certificates
if (!load_certificates_from_memory(conn->ctx, tls_server->cert, tls_server->cert_len, tls_server->key, tls_server->key_len)) {
fprintf(stderr, "Failed to load certificates\n");
goto cleanup_ctx;
}
// Add custom TLS extension
ret = add_custom_tls_extension(conn->ctx, conn);
if (!ret) {
perror("Unable to add custom tls extensions");
goto cleanup_ctx;
}
// Accept client connection
client_fd = accept(tls_server->server_fd, NULL, NULL);
if (client_fd < 0) {
perror("Unable to accept connection");
goto cleanup_ctx;
}
// Create SSL object
conn->ssl = SSL_new(conn->ctx);
if (conn->ssl == NULL) {
perror("Unable to create SSL object");
goto cleanup_fd;
}
// Set file descriptor
conn->socket_fd = client_fd;
SSL_set_fd(conn->ssl, client_fd);
// Get local address
if (getsockname(client_fd, (struct sockaddr *)&conn->local_addr, &len) == -1) {
perror("getsockname failed during TLS server accept");
goto cleanup_ssl;
}
// Get remote address
if (getpeername(client_fd, (struct sockaddr *)&conn->remote_addr, &len) == -1) {
perror("getpeername failed during TLS server accept");
goto cleanup_ssl;
}
// Perform SSL handshake
ret = SSL_accept(conn->ssl);
if (ret <= 0) {
perror("SSL handshake failed during accept");
goto cleanup_ssl;
}
return conn;
cleanup_ssl:
if (conn->ssl) SSL_free(conn->ssl);
cleanup_fd:
if (client_fd >= 0) close(client_fd);
cleanup_ctx:
if (conn->ctx) SSL_CTX_free(conn->ctx);
cleanup_conn:
free(conn);
return NULL;
}
// Function to close the server
int tls_server_close(tls_server_connection *tls_server) {
close(tls_server->server_fd);
cleanup_openssl();
free(tls_server->cert);
free(tls_server->key);
free(tls_server);
return 0;
}
int tls_read(tls_connection *conn, void *buf, int num) {
return SSL_read(conn->ssl, buf, num);
}
int tls_write(tls_connection *conn, const void *buf, int num) {
if (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN) {
return 0;
}
return SSL_write(conn->ssl, buf, num);
}
int tls_close(tls_connection *conn) {
if (conn != NULL) {
if (conn->ssl != NULL) {
int ret = 0;
if (SSL_has_pending(conn->ssl) == 1 || (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN)) {
int num = SSL_pending(conn->ssl);
char c[num];
int res = 0;
int end = 0;
res = SSL_read(conn->ssl, (void*)c, num);
res = SSL_get_error(conn->ssl, res);
if (res == SSL_ERROR_ZERO_RETURN) {
end = 1;
} else if (res != SSL_ERROR_NONE) {
fprintf(stderr, "SSL_read failed in TLS close call\n");
end = 1;
}
if ((SSL_get_shutdown(conn->ssl) & SSL_RECEIVED_SHUTDOWN) || end == 1) {
ret = SSL_shutdown(conn->ssl);
}
} else {
ret = SSL_shutdown(conn->ssl);
}
if (ret < 0) {
ret = SSL_get_error(conn->ssl, ret);
fprintf(stderr, "SSL did not shutdown correctly, error code: %d\n", ret);
free(conn);
close(conn->socket_fd);
conn = NULL;
return -1;
} else if (ret == 0) {
return 0;
}
conn->ssl = NULL;
}
if (conn->socket_fd >= 0) {
close(conn->socket_fd);
conn->socket_fd = -1;
}
SSL_free(conn->ssl);
if (conn->ctx != NULL) {
SSL_CTX_free(conn->ctx);
conn->ctx = NULL;
}
free(conn);
conn = NULL;
return 1;
}
return 1;
}
char* tls_return_addr(struct sockaddr_storage *addr) {
socklen_t addr_len = sizeof(struct sockaddr_storage);
int inet_size =addr->ss_family == AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN;
char *ip_str = (char*)malloc(inet_size*sizeof(char));
void * addr_ptr;
if (ip_str == NULL) {
perror("memory could not be allocated");
return NULL;
}
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
addr_ptr = &(ipv4->sin_addr);
} else if (addr->ss_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
addr_ptr = &(ipv6->sin6_addr);
} else {
fprintf(stderr, "unknown family: %d\n", addr->ss_family);
free(ip_str);
return NULL;
}
if (inet_ntop(addr->ss_family, addr_ptr, ip_str, inet_size) == NULL) {
perror("inet_ntop failed");
free(ip_str);
return NULL;
}
return ip_str;
}
int tls_return_port(struct sockaddr_storage *addr) {
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
return ntohs(ipv4->sin_port);
} else if (addr->ss_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
return ntohs(ipv6->sin6_port);
}
fprintf(stderr, "cannot return port from unknown family: %d\n", addr->ss_family);
return -1;
}
tls_connection* new_tls_connection(char *address, int port) {
SSL_CTX *ctx = NULL;
SSL *ssl = NULL;
int socket_fd = -1;
int status;
struct addrinfo hints, *res = NULL, *p = NULL;
char port_str[6];
tls_connection *conn = NULL;
socklen_t addr_len;
int ret = 0;
conn = (tls_connection*)malloc(sizeof(tls_connection));
if (!conn) {
perror("Failed to allocate memory for atls connection");
return NULL;
}
// Format the port string
snprintf(port_str, sizeof(port_str), "%d", port);
// Initialize OpenSSL
init_openssl();
// Create SSL context
ctx = create_context(TLS_CLIENT_CTX);
if (!ctx) {
perror("Could not create context");
goto cleanup_conn;
}
conn->ctx = ctx;
// Add custom TLS extension
ret = add_custom_tls_extension(conn->ctx, conn);
if (!ret) {
perror("Unable to add custom tls extensions");
goto cleanup_ctx;
}
// Prepare the address info hints
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = 0;
hints.ai_protocol = 0;
// Get address info
status = getaddrinfo(address, port_str, &hints, &res);
if (status != 0) {
perror("getaddrinfo error");
goto cleanup_ctx;
}
// Iterate through the results and try to connect
for (p = res; p != NULL; p = p->ai_next) {
socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd < 0) {
perror("unable to create socket");
continue;
}
if (connect(socket_fd, p->ai_addr, p->ai_addrlen)) {
close(socket_fd);
continue;
}
memcpy(&conn->local_addr, p->ai_addr, p->ai_addrlen);
break;
}
freeaddrinfo(res);
if (p == NULL) {
goto cleanup_ctx;
}
conn->socket_fd = socket_fd;
// Retrieve and store the remote address
addr_len = sizeof(conn->remote_addr);
if (getpeername(socket_fd, (struct sockaddr *)&conn->remote_addr, &addr_len) == -1) {
perror("getpeername failed");
goto cleanup_socket;
}
// Create the SSL structure
ssl = SSL_new(ctx);
if (!ssl) {
perror("Failed to create SSL object");
goto cleanup_socket;
}
if (!SSL_set_fd(ssl, socket_fd)) {
perror("Failed to set SSL file descriptor");
goto cleanup_ssl;
}
conn->ssl = ssl;
// Perform the SSL handshake
if (SSL_connect(ssl) <= 0) {
fprintf(stderr, "SSL handshake failed\n");
goto cleanup_ssl;
}
return conn;
cleanup_ssl:
if (ssl) SSL_free(ssl);
cleanup_socket:
if (socket_fd >= 0) close(socket_fd);
cleanup_ctx:
if (ctx) SSL_CTX_free(ctx);
cleanup_conn:
free(conn);
return NULL;
}
int set_socket_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
+1 -1
View File
@@ -134,7 +134,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
}
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
return vtpm.VerifyQuote(report, nil, vTpmNonce, a.writer, a.Policy)
return vtpm.VerifyQuote(report, vTpmNonce, a.writer, a.Policy)
}
func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
+11 -57
View File
@@ -6,8 +6,6 @@ package vtpm
import (
"bytes"
"crypto"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"encoding/json"
"fmt"
@@ -105,14 +103,12 @@ func ExtendPCR(pcrIndex int, value []byte) error {
}
type provider struct {
pubKey []byte
teeAttestaion bool
vmpl uint
}
func NewProvider(pubKey []byte, teeAttestation bool, vmpl uint) attestation.Provider {
func NewProvider(teeAttestation bool, vmpl uint) attestation.Provider {
return &provider{
pubKey: pubKey,
teeAttestaion: teeAttestation,
vmpl: vmpl,
}
@@ -140,19 +136,17 @@ func (v provider) AzureAttestationToken(tokenNonce []byte) ([]byte, error) {
}
type verifier struct {
pubKey []byte
writer io.Writer
Policy *attestation.Config
}
func NewVerifier(pubKey []byte, writer io.Writer) attestation.Verifier {
func NewVerifier(writer io.Writer) attestation.Verifier {
policy := &attestation.Config{
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
PcrConfig: &attestation.PcrConfig{},
}
return &verifier{
pubKey: pubKey,
writer: writer,
Policy: policy,
}
@@ -160,11 +154,10 @@ func NewVerifier(pubKey []byte, writer io.Writer) attestation.Verifier {
func NewVerifierWithPolicy(pubKey []byte, writer io.Writer, policy *attestation.Config) attestation.Verifier {
if policy == nil {
return NewVerifier(pubKey, writer)
return NewVerifier(writer)
}
return &verifier{
pubKey: pubKey,
writer: writer,
Policy: policy,
}
@@ -181,11 +174,11 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
}
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
return VerifyQuote(report, v.pubKey, vTpmNonce, v.writer, v.Policy)
return VerifyQuote(report, vTpmNonce, v.writer, v.Policy)
}
func (v verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
return VTPMVerify(report, v.pubKey, teeNonce, vTpmNonce, v.writer, v.Policy)
return VTPMVerify(report, teeNonce, vTpmNonce, v.writer, v.Policy)
}
func (v verifier) JSONToPolicy(path string) error {
@@ -208,8 +201,8 @@ func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool, vmpl uint) ([
return marshalQuote(attestation)
}
func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
if err := VerifyQuote(quote, pubKeyTLS, vtpmNonce, writer, policy); err != nil {
func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
if err := VerifyQuote(quote, vtpmNonce, writer, policy); err != nil {
return fmt.Errorf("failed to verify vTPM quote: %v", err)
}
@@ -227,7 +220,7 @@ func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byt
return nil
}
func VerifyQuote(quote []byte, pubKeyTLS []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
func VerifyQuote(quote []byte, vtpmNonce []byte, writer io.Writer, policy *attestation.Config) error {
attestation := &attest.Attestation{}
err := proto.Unmarshal(quote, attestation)
@@ -251,9 +244,7 @@ func VerifyQuote(quote []byte, pubKeyTLS []byte, vtpmNonce []byte, writer io.Wri
return errors.Wrap(fmt.Errorf("failed to verify attestation"), err)
}
s256, s384 := calculatePCRTLSKey(pubKeyTLS)
if err := checkExpectedPCRValues(attestation, s256, s384, policy); err != nil {
if err := checkExpectedPCRValues(attestation, policy); err != nil {
return fmt.Errorf("PCR values do not match expected PCR values: %w", err)
}
@@ -330,39 +321,23 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
return nil
}
func checkExpectedPCRValues(attQuote *attest.Attestation, ePcr256, ePcr384 []byte, policy *attestation.Config) error {
func checkExpectedPCRValues(attQuote *attest.Attestation, policy *attestation.Config) error {
quotes := attQuote.GetQuotes()
for i := range quotes {
quote := quotes[i]
var pcrMap map[string]string
var pcr15 []byte
switch quote.Pcrs.Hash {
case ptpm.HashAlgo_SHA256:
pcrMap = policy.PcrConfig.PCRValues.Sha256
if ePcr256 == nil {
pcr15 = make([]byte, 32)
} else {
pcr15 = ePcr256
}
case ptpm.HashAlgo_SHA384:
pcrMap = policy.PcrConfig.PCRValues.Sha384
if ePcr384 == nil {
pcr15 = make([]byte, 48)
} else {
pcr15 = ePcr384
}
case ptpm.HashAlgo_SHA1:
pcrMap = policy.PcrConfig.PCRValues.Sha1
pcr15 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
default:
return errors.Wrap(ErrNoHashAlgo, fmt.Errorf("algo: %s", ptpm.HashAlgo_name[int32(quote.Pcrs.Hash)]))
}
pcr15Index := uint32(15)
if !bytes.Equal(quote.Pcrs.Pcrs[pcr15Index], pcr15) {
return fmt.Errorf("for algo %s PCR[15] expected %s but found %s", ptpm.HashAlgo_name[int32(quote.Pcrs.Hash)], hex.EncodeToString(pcr15), hex.EncodeToString(quote.Pcrs.Pcrs[pcr15Index]))
}
for i, v := range pcrMap {
index, err := strconv.ParseInt(i, 10, 32)
if err != nil {
@@ -380,27 +355,6 @@ func checkExpectedPCRValues(attQuote *attest.Attestation, ePcr256, ePcr384 []byt
return nil
}
// Return SHA256 and SHA384 values of the input public key.
func calculatePCRTLSKey(pubKey []byte) ([]byte, []byte) {
if len(pubKey) == 0 {
return nil, nil
}
init256 := make([]byte, Hash256)
init384 := make([]byte, Hash384)
key256 := sha3.Sum256(pubKey)
key384 := sha3.Sum384(pubKey)
pcrValue256 := append(init256, key256[:]...)
pcrValue384 := append(init384, key384[:]...)
newPcr256 := sha256.Sum256(pcrValue256)
newPcr384 := sha512.Sum384(pcrValue384)
return newPcr256[:], newPcr384[:]
}
func getPCRValue(index int, algorithm tpm2.Algorithm) ([]byte, error) {
rwc, err := OpenTpm()
if err != nil {
+34 -40
View File
@@ -1,19 +1,16 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build cgo
package grpc
import (
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"net"
"os"
"strconv"
"time"
"github.com/absmach/magistrala/pkg/errors"
@@ -36,12 +33,9 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
attestation.AttestationPolicyPath = cfg.AttestationPolicy
var insecureSkipVerify bool = true
var rootCAs *x509.CertPool = nil
if len(cfg.ServerCAFile) > 0 {
insecureSkipVerify = false
// Read the certificate file
certPEM, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
@@ -66,11 +60,20 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
security = withmaTLS
}
nonce := make([]byte, 64)
if _, err := rand.Read(nonce); err != nil {
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to generate nonce"), err)
}
encoded := hex.EncodeToString(nonce)
sni := fmt.Sprintf("%s.nonce", encoded)
tlsConfig := &tls.Config{
InsecureSkipVerify: insecureSkipVerify,
InsecureSkipVerify: true,
RootCAs: rootCAs,
ServerName: sni,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificateATLS(rawCerts, verifiedChains, cfg)
return verifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
},
}
@@ -85,49 +88,40 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, securit
return credentials.NewTLS(tlsConfig), security, nil
}
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("could not create a custom dialer")
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, fmt.Errorf("bad format of IP address: %v", err)
}
conn, err := atls.DialTLSClient(ip, p)
if err != nil {
return nil, fmt.Errorf("could not create TLS connection")
}
return conn, nil
}
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, cfg AgentClientConfig) error {
if len(cfg.ServerCAFile) > 0 {
return nil
}
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errors.Wrap(errCertificateParse, err)
}
err = checkIfCertificateSelfSigned(cert)
err = checkIfCertificateSigned(cert, rootCAs)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
return nil
for _, ext := range cert.Extensions {
pType, err := atls.GetPlatformTypeFromOID(ext.Id)
if err == nil {
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
return atls.VerifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
}
}
return fmt.Errorf("attestation extension not found in certificate")
}
func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
certPool := x509.NewCertPool()
certPool.AddCert(cert)
func checkIfCertificateSigned(cert *x509.Certificate, rootCAs *x509.CertPool) error {
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: certPool,
Roots: rootCAs,
CurrentTime: time.Now(),
}
-16
View File
@@ -1,16 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !cgo
package grpc
import (
"fmt"
"google.golang.org/grpc/credentials"
)
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
return nil, fmt.Errorf("aTLS is not supported without CGO. Please rebuild with CGO_ENABLED=1")
}
+1 -1
View File
@@ -377,7 +377,7 @@ func TestCheckIfCertificateSelfSigned(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkIfCertificateSelfSigned(tt.cert)
err := checkIfCertificateSigned(tt.cert, nil)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
-1
View File
@@ -148,7 +148,6 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
}
opts = append(opts, grpc.WithTransportCredentials(tc))
opts = append(opts, grpc.WithContextDialer(CustomDialer))
secure = sec
} else {