COCOS-192 - Add support for attested TLS (#279)

* add draft tls extension

* add client support for ipv6

* remove vscode

* add evidence request server payload

* clean up the code

* add fetch and verify for quote provider

* add build parameters for buildroot

* change Makefile to always enable CGO

* fix ci

* add malloc check for NULL

* add copyright

* renamed files and fix cgo lint

* fix cache test

* fix server tests

* remove ineffective assignment

* fix no-TLS connection

* add check for SSL_set_fd failure

* add tests for verification of attestation

* fix CI

* fix failing tests

* fix backend tests

* remove commented code

* separate verify and validate function

* fix failing test

* Simplify function name

---------

Co-authored-by: ultraviolet <cocosai@ultraviolet.local.pragmatic-it.com>
This commit is contained in:
Danko Miladinovic
2024-11-04 19:10:34 +01:00
committed by GitHub
parent 6f747190b9
commit e372cfc219
28 changed files with 2056 additions and 591 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
BUILD_DIR = build
SERVICES = manager agent cli
BACKEND_INFO = backend_info
CGO_ENABLED ?= 0
CGO_ENABLED ?= 1
GOARCH ?= amd64
VERSION ?= $(shell git describe --abbrev=0 --tags --always)
COMMIT ?= $(shell git rev-parse HEAD)
-13
View File
@@ -1,13 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !embed
// +build !embed
package quoteprovider
import "github.com/google/go-sev-guest/client"
func GetQuoteProvider() (client.QuoteProvider, error) {
return client.GetQuoteProvider()
}
+2 -2
View File
@@ -18,10 +18,10 @@ import (
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/events/mocks"
"github.com/ultravioletrs/cocos/agent/quoteprovider"
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/agent/statemachine"
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
)
+3 -56
View File
@@ -17,12 +17,10 @@ import (
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/tools/lib/report"
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@@ -269,7 +267,7 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command {
return
}
if err := verifyAndValidateAttestation(attestation); err != nil {
if err := quoteprovider.VerifyAndValidate(attestation, &cfg); err != nil {
printError(cmd, "Attestation validation and verification failed with error: %v ❌ ", err)
return
}
@@ -464,57 +462,6 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command {
return cmd
}
func verifyAndValidateAttestation(attestation []byte) error {
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
if err != nil {
return err
}
if cfg.Policy.Product == nil {
productName := sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN
switch cfg.RootOfTrust.ProductLine {
case sevProductNameMilan:
productName = sevsnp.SevProduct_SEV_PRODUCT_MILAN
case sevProductNameGenoa:
productName = sevsnp.SevProduct_SEV_PRODUCT_GENOA
default:
}
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
return fmt.Errorf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa)
}
sopts.Product = &sevsnp.SevProduct{
Name: productName,
}
} else {
sopts.Product = cfg.Policy.Product
}
sopts.Getter = &trust.RetryHTTPSGetter{
Timeout: timeout,
MaxRetryDelay: maxRetryDelay,
Getter: &trust.SimpleHTTPSGetter{},
}
// Only take the attestation report and ignore everything else.
attestationPB, err := abi.ReportCertsToProto(attestation[:abi.ReportSize])
if err != nil {
return err
}
if err = verify.SnpAttestation(attestationPB, sopts); err != nil {
return err
}
opts, err := validate.PolicyToOptions(cfg.Policy)
if err != nil {
return err
}
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
return err
}
return nil
}
// parseConfig decodes config passed as json for check.Config struct.
// example
/* {
@@ -681,7 +628,7 @@ func getBase(val string) int {
}
func validateInput() error {
if len(cfg.RootOfTrust.CabundlePaths) != 0 || len(cfg.RootOfTrust.Cabundles) != 0 && cfg.RootOfTrust.Product == "" {
if len(cfg.RootOfTrust.CabundlePaths) != 0 || len(cfg.RootOfTrust.Cabundles) != 0 && cfg.RootOfTrust.ProductLine == "" {
return fmt.Errorf("product name must be set if CA bundles are provided")
}
+6 -11
View File
@@ -4,7 +4,6 @@ package cli
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
@@ -12,6 +11,7 @@ import (
"github.com/google/go-sev-guest/proto/check"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/protobuf/encoding/protojson"
)
type fieldType int
@@ -39,11 +39,6 @@ var (
errBackendField = errors.New("the specified field type does not exist in the backend information")
)
type AttestationConfiguration struct {
SNPPolicy *check.Policy `json:"snp_policy,omitempty"`
RootOfTrust *check.RootOfTrust `json:"root_of_trust,omitempty"`
}
func (cli *CLI) NewBackendCmd() *cobra.Command {
return &cobra.Command{
Use: "backend [command]",
@@ -114,27 +109,27 @@ func changeAttestationConfiguration(fileName, base64Data string, expectedLength
return errDataLength
}
ac := AttestationConfiguration{}
ac := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
backendInfo, err := os.ReadFile(fileName)
if err != nil {
return errors.Wrap(errReadingBackendInfoFile, err)
}
if err = json.Unmarshal(backendInfo, &ac); err != nil {
if err = protojson.Unmarshal(backendInfo, &ac); err != nil {
return errors.Wrap(errUnmarshalJSON, err)
}
switch field {
case measurementField:
ac.SNPPolicy.Measurement = data
ac.Policy.Measurement = data
case hostDataField:
ac.SNPPolicy.HostData = data
ac.Policy.HostData = data
default:
return errBackendField
}
fileJson, err := json.MarshalIndent(ac, "", " ")
fileJson, err := protojson.Marshal(&ac)
if err != nil {
return errors.Wrap(errMarshalJSON, err)
}
+7 -12
View File
@@ -4,13 +4,13 @@ package cli
import (
"encoding/base64"
"encoding/json"
"os"
"testing"
"github.com/google/go-sev-guest/proto/check"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protojson"
)
func TestChangeAttestationConfiguration(t *testing.T) {
@@ -18,14 +18,9 @@ func TestChangeAttestationConfiguration(t *testing.T) {
require.NoError(t, err)
defer os.Remove(tmpfile.Name())
initialConfig := AttestationConfiguration{
SNPPolicy: &check.Policy{
Measurement: make([]byte, measurementLength),
HostData: make([]byte, hostDataLength),
},
}
initialConfig := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
initialJSON, err := json.Marshal(initialConfig)
initialJSON, err := protojson.Marshal(&initialConfig)
require.NoError(t, err)
err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644)
require.NoError(t, err)
@@ -91,15 +86,15 @@ func TestChangeAttestationConfiguration(t *testing.T) {
content, err := os.ReadFile(tmpfile.Name())
require.NoError(t, err)
var config AttestationConfiguration
err = json.Unmarshal(content, &config)
config := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
err = protojson.Unmarshal(content, &config)
require.NoError(t, err)
decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data)
if tt.field == measurementField {
assert.Equal(t, decodedData, config.SNPPolicy.Measurement)
assert.Equal(t, decodedData, config.Policy.Measurement)
} else if tt.field == hostDataField {
assert.Equal(t, decodedData, config.SNPPolicy.HostData)
assert.Equal(t, decodedData, config.Policy.HostData)
}
}
})
+5 -4
View File
@@ -9,6 +9,7 @@ import (
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/kds"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/verify/trust"
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
@@ -26,14 +27,14 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command {
Example: "ca-bundle <path_to_platform_info_json>",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
attestationConfiguration := grpc.AttestationConfiguration{}
attestationConfiguration := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
err := grpc.ReadBackendInfo(args[0], &attestationConfiguration)
if err != nil {
printError(cmd, "Error while reading manifest: %v ❌ ", err)
return
}
product := attestationConfiguration.RootOfTrust.Product
product := attestationConfiguration.RootOfTrust.ProductLine
getter := trust.DefaultHTTPSGetter()
caURL := kds.ProductCertChainURL(abi.VcekReportSigner, product)
@@ -54,8 +55,8 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command {
return
}
bundleFilePath := path.Join(fileSavePath, product, caBundleName)
if err = saveToFile(bundleFilePath, bundle); err != nil {
bundlePath := path.Join(fileSavePath, product, caBundleName)
if err = saveToFile(bundlePath, bundle); err != nil {
printError(cmd, "Error while saving ARK-ASK to file: %v ❌ ", err)
return
}
+1 -1
View File
@@ -17,7 +17,7 @@ func TestNewCABundleCmd(t *testing.T) {
assert.NoError(t, err)
defer os.RemoveAll(tempDir)
manifestContent := []byte(`{"root_of_trust": {"product": "Milan"}}`)
manifestContent := []byte(`{"root_of_trust": {"product_line": "Milan"}}`)
manifestPath := path.Join(tempDir, "manifest.json")
err = os.WriteFile(manifestPath, manifestContent, 0o644)
assert.NoError(t, err)
+1 -1
View File
@@ -24,13 +24,13 @@ import (
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/agent/quoteprovider"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
managerevents "github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"golang.org/x/crypto/sha3"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
+1 -1
View File
@@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/events/mocks"
qpmocks "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
qpmocks "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
)
func TestSetDefaultValues(t *testing.T) {
+26 -26
View File
@@ -11,7 +11,6 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"log/slog"
@@ -25,8 +24,8 @@ import (
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/internal/server"
"github.com/ultravioletrs/cocos/pkg/atls"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@@ -45,6 +44,7 @@ const (
notAfterYear = 1
notAfterMonth = 0
notAfterDay = 0
nonceSize = 32
)
type Server struct {
@@ -89,15 +89,12 @@ func (s *Server) Start() error {
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
}
listener, err := net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
creds := grpc.Creds(insecure.NewCredentials())
var listener net.Listener = nil
switch {
case s.Config.AttestedTLS:
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.quoteProvider)
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS()
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}
@@ -113,7 +110,17 @@ func (s *Server) Start() error {
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
listener, err = atls.Listen(
s.Address,
certificateBytes,
privateKeyBytes,
)
if err != nil {
return fmt.Errorf("failed to create Listener for aTLS: %w", err)
}
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" || s.Config.KeyFile != "":
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
@@ -161,7 +168,18 @@ func (s *Server) Start() error {
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
}
listener, err = net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
default:
var err error
listener, err = net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
@@ -237,24 +255,13 @@ func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
return tls.X509KeyPair(cert, key)
}
func generateCertificatesForATLS(qp client.QuoteProvider) ([]byte, []byte, error) {
func generateCertificatesForATLS() ([]byte, []byte, error) {
curve := elliptic.P256()
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private/public key: %w", err)
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal the public key: %w", err)
}
// The Attestation Report will be added as an X.509 certificate extension
attestationReport, err := qp.GetRawQuote(sha3.Sum512(publicKeyBytes))
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch the attestation report: %w", err)
}
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(202403311),
Subject: pkix.Name{
@@ -270,13 +277,6 @@ func generateCertificatesForATLS(qp client.QuoteProvider) ([]byte, []byte, error
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6},
Critical: false,
Value: attestationReport,
},
},
}
certDERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
+2 -4
View File
@@ -18,10 +18,9 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
authmocks "github.com/ultravioletrs/cocos/agent/mocks"
"github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/internal/server"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
@@ -139,7 +138,6 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
qp.On("GetRawQuote", mock.Anything).Return([]byte("mock-quote"), nil)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
@@ -158,7 +156,7 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
cancel()
time.Sleep(100 * time.Millisecond)
time.Sleep(1000 * time.Millisecond)
logContent := logBuffer.String()
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
+8 -8
View File
@@ -9,22 +9,22 @@ package manager
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/google/go-sev-guest/proto/check"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/virtee/sev-snp-measure-go/cpuid"
"github.com/virtee/sev-snp-measure-go/guest"
"github.com/virtee/sev-snp-measure-go/vmmtypes"
"google.golang.org/protobuf/encoding/protojson"
)
const defGuestFeatures = 0x1
func (ms *managerService) FetchBackendInfo(_ context.Context, computationId string) ([]byte, error) {
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "1966081")
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "196608")
ms.mu.Lock()
vm, exists := ms.vms[computationId]
@@ -48,9 +48,9 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
return nil, err
}
var backendInfo grpc.AttestationConfiguration
var backendInfo check.Config
if err = json.Unmarshal(f, &backendInfo); err != nil {
if err = protojson.Unmarshal(f, &backendInfo); err != nil {
return nil, err
}
@@ -68,7 +68,7 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
}
}
if measurement == nil {
backendInfo.SNPPolicy.Measurement = measurement
backendInfo.Policy.Measurement = measurement
}
if config.HostData != "" {
@@ -76,10 +76,10 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
if err != nil {
return nil, err
}
backendInfo.SNPPolicy.HostData = hostData
backendInfo.Policy.HostData = hostData
}
f, err = json.Marshal(backendInfo)
f, err = protojson.Marshal(&backendInfo)
if err != nil {
return nil, err
}
+1 -1
View File
@@ -20,7 +20,7 @@ func createDummyBackendInfoBinary(t *testing.T, behavior string) string {
switch behavior {
case "success":
content = []byte(`#!/bin/sh
echo '{"snp_policy": {"measurement": null, "host_data": null}}' > backend_info.json
echo '{"policy": {"measurement": null, "host_data": null}}' > backend_info.json
`)
case "fail":
content = []byte(`#!/bin/sh
+419
View File
@@ -0,0 +1,419 @@
// Code generated by cmd/cgo; DO NOT EDIT.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
// #cgo LDFLAGS: -lssl -lcrypto
// #include "extensions.h"
import "C"
import (
"fmt"
"io"
"net"
"os"
"runtime/cgo"
"strconv"
"sync"
"syscall"
"time"
"unsafe"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
)
const (
NoTee int = iota
AmdSevSnp
)
const (
NO_ERROR = 0
ERROR_ZERO_RETURN = 6
ERROR_WANT_READ = 2
ERROR_WANT_WRITE = 3
ERROR_SYSCALL = 5
ERROR_SSL = 1
)
var (
errListener = errors.New("listener could not be created")
errBadIPFormat = errors.New("bad format of IP address")
errCloseTLS = errors.New("could not close TLS connection")
errConnFailed = errors.New("tls connection is nil")
errWrite = errors.New("could not write to TLS")
errTLSConn = errors.New("connection did not close correctly")
errReadDeadline = errors.New("could not set read deadline, socket timeout failed")
errWriteDeadline = errors.New("could not set write deadline, socket timeout failed")
errConnCreate = errors.New("could not create connection")
)
type ValidationVerification func(data1, data2 []byte) error
type FetchAttestation func(data1 []byte) ([]byte, error)
func registerFetchAttestation(callback FetchAttestation) uintptr {
handle := cgo.NewHandle(callback)
return uintptr(handle)
}
func registerValidationVerification(callback ValidationVerification) uintptr {
handle := cgo.NewHandle(callback)
return uintptr(handle)
}
//export validationVerificationCallback
func validationVerificationCallback(teeType C.int) uintptr {
switch int(teeType) {
case NoTee:
return uintptr(0)
case AmdSevSnp:
return registerValidationVerification(quoteprovider.VerifyAttestationReportTLS)
default:
return uintptr(0)
}
}
//export fetchAttestationCallback
func fetchAttestationCallback(teeType C.int) uintptr {
switch int(teeType) {
case NoTee:
return uintptr(0)
case AmdSevSnp:
return registerFetchAttestation(quoteprovider.FetchAttestation)
default:
return uintptr(0)
}
}
//export callVerificationValidationCallback
func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uchar, attReportSize C.int, repData *C.uchar) C.int {
handle := cgo.Handle(callbackHandle)
defer handle.Delete()
callback := handle.Value().(ValidationVerification)
attestationReport := C.GoBytes(unsafe.Pointer(attReport), attReportSize)
reportData := C.GoBytes(unsafe.Pointer(repData), agent.ReportDataSize)
err := callback(attestationReport, reportData)
if err != nil {
fmt.Fprintf(os.Stderr, "callback failed %v", err)
return C.int(-1)
}
return C.int(0)
}
//export callFetchAttestationCallback
func callFetchAttestationCallback(callbackHandle uintptr, reportDataByte *C.uchar, outlen *C.int) *C.uchar {
handle := cgo.Handle(callbackHandle)
defer handle.Delete()
callback := handle.Value().(FetchAttestation)
reportData := C.GoBytes(unsafe.Pointer(reportDataByte), agent.ReportDataSize)
quote, err := callback(reportData)
if err != nil {
fmt.Fprintf(os.Stderr, "attestation callback returned nil")
return nil
}
*outlen = C.int(len(quote))
resultC := C.malloc(C.size_t(len(quote)))
if resultC == nil {
fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback")
return nil
}
C.memcpy(resultC, unsafe.Pointer(&quote[0]), C.size_t(len(quote)))
return (*C.uchar)(resultC)
}
type ATLSServerListener struct {
tlsListener *C.tls_server_connection
}
func Listen(addr string, cert []byte, key []byte) (net.Listener, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, errors.Wrap(errListener, err)
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, errors.Wrap(errBadIPFormat, err)
}
cCertPEM := (*C.char)(unsafe.Pointer(&cert[0]))
cKeyPEM := (*C.char)(unsafe.Pointer(&key[0]))
cIP := C.CString(ip)
defer C.free(unsafe.Pointer(cIP))
atlsListener := C.start_tls_server(
cCertPEM, C.int(len(cert)),
cKeyPEM, C.int(len(key)),
cIP, C.int(p))
if atlsListener == nil {
return nil, errors.Wrap(errListener, err)
}
return &ATLSServerListener{tlsListener: atlsListener}, nil
}
// accept implements the Accept method in the [Listener] interface; it
// waits for the next call and returns a generic [Conn].
func (l *ATLSServerListener) Accept() (net.Conn, error) {
conn := C.tls_server_accept(l.tlsListener)
if conn == nil {
return &ATLSConn{tlsConn: nil}, nil
}
return &ATLSConn{tlsConn: conn}, nil
}
// close stops listening on the TCP address.
// already Accepted connections are not closed.
func (l *ATLSServerListener) Close() error {
ret := C.tls_server_close(l.tlsListener)
if ret != 0 {
return errCloseTLS
}
return nil
}
// addr returns the listener's network address, a [*TCPAddr].
// the Addr returned is shared by all invocations of Addr, so
// do not modify it.
func (l *ATLSServerListener) Addr() net.Addr {
cIP := C.tls_return_addr(&l.tlsListener.addr)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoString(cIP)
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return nil
}
port := C.tls_return_port(&l.tlsListener.addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
type ATLSConn struct {
tlsConn *C.tls_connection
fdReadMutex sync.Mutex
fdWriteMutex sync.Mutex
fdDelayMutex sync.Mutex
}
func (c *ATLSConn) Read(b []byte) (int, error) {
c.fdReadMutex.Lock()
defer c.fdReadMutex.Unlock()
if c.tlsConn == nil {
return 0, errConnFailed
}
n := int(C.tls_read(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
if n > 0 {
return n, nil
}
// call the C function SSL_get_error to interpret the error.
errCode := int(C.SSL_get_error(c.tlsConn.ssl, C.int(n)))
// handle specific error codes returned by SSL_get_error.
switch errCode {
case NO_ERROR:
return n, nil // no error.
case ERROR_ZERO_RETURN:
fmt.Fprintf(os.Stderr, "Connection closed by peer")
return 0, io.EOF // connection closed.
case ERROR_WANT_READ:
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later")
return 0, nil // non-fatal, just retry later.
case ERROR_WANT_WRITE:
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later")
return 0, nil // non-fatal, just retry later.
case ERROR_SYSCALL:
fmt.Fprintf(os.Stderr, "I/O error")
return 0, syscall.ECONNRESET // return connection reset error.
case ERROR_SSL:
fmt.Fprintf(os.Stderr, "I/O error")
return 0, syscall.ECONNRESET // return connection reset error.
default:
fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode)
return 0, fmt.Errorf("SSL error")
}
}
func (c *ATLSConn) Write(b []byte) (int, error) {
c.fdWriteMutex.Lock()
defer c.fdWriteMutex.Unlock()
if c.tlsConn == nil {
return 0, errConnFailed
}
n := int(C.tls_write(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
if n < 0 {
return 0, errWrite
}
return n, nil
}
func (c *ATLSConn) Close() error {
c.fdReadMutex.Lock()
defer c.fdReadMutex.Unlock()
c.fdWriteMutex.Lock()
defer c.fdWriteMutex.Unlock()
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
ret := C.tls_close(c.tlsConn)
if int(ret) < 0 {
c.tlsConn = nil
return errTLSConn
} else if int(ret) == 1 {
c.tlsConn = nil
}
return nil
}
func (c *ATLSConn) LocalAddr() net.Addr {
if c.tlsConn == nil {
return nil
}
cIP := C.tls_return_addr(&c.tlsConn.local_addr)
ipLength := C.strlen(cIP)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoStringN(cIP, C.int(ipLength))
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
fmt.Println("Invalid IP address")
return nil
}
port := C.tls_return_port(&c.tlsConn.local_addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
func (c *ATLSConn) RemoteAddr() net.Addr {
if c.tlsConn == nil {
return nil
}
cIP := C.tls_return_addr(&c.tlsConn.remote_addr)
if cIP == nil {
fmt.Println("RemoteAddr error while fetching ip")
return nil
}
ipLength := C.strlen(cIP)
defer C.free(unsafe.Pointer(cIP))
ip := C.GoStringN(cIP, C.int(ipLength))
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
fmt.Println("Invalid IP address")
return nil
}
port := C.tls_return_port(&c.tlsConn.remote_addr)
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
}
func (c *ATLSConn) SetDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errReadDeadline
}
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errWriteDeadline
}
return nil
}
func (c *ATLSConn) SetReadDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errReadDeadline
}
return nil
}
func (c *ATLSConn) SetWriteDeadline(t time.Time) error {
c.fdDelayMutex.Lock()
defer c.fdDelayMutex.Unlock()
if c.tlsConn == nil {
return nil
}
sec, usec := timeToTimeout(t)
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
return errWriteDeadline
}
return nil
}
func DialTLSClient(hostname string, port int) (net.Conn, error) {
cHostName := C.CString(hostname)
defer C.free(unsafe.Pointer(cHostName))
conn := C.new_tls_connection(cHostName, C.int(port))
if conn == nil {
return nil, errConnCreate
}
return &ATLSConn{tlsConn: conn}, nil
}
func timeToTimeout(t time.Time) (int, int) {
if t.IsZero() {
return 0, 0
}
d := time.Until(t)
if d <= 0 {
return 0, 0
}
seconds := int(d.Seconds())
microseconds := int(d.Nanoseconds()/1000) % 1_000_000
return seconds, microseconds
}
+380
View File
@@ -0,0 +1,380 @@
#include "extensions.h"
#include <stdio.h>
#include <string.h>
#include <openssl/sha.h>
#include <openssl/rand.h>
#include <openssl/x509.h>
#include <fcntl.h>
#include <unistd.h>
extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* attReport, int attReportSize, const u_char* repData);
extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* reportDataByte, int* outlen);
extern uintptr_t validationVerificationCallback(int teeType);
extern uintptr_t fetchAttestationCallback(int teeType);
int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char *attestationReport, int reportSize, u_char *reportData) {
if (attestationReport == NULL || reportData == NULL) {
fprintf(stderr, "attestation data and report data cannot be NULL\n");
return -1;
}
return callVerificationValidationCallback(callbackHandle, attestationReport, reportSize, reportData);
}
u_char* triggerFetchAttestationCallback(uintptr_t callbackHandle, char *reportData) {
int outlen = REPORT_DATA_SIZE;
if(reportData == NULL) {
fprintf(stderr, "Report data cannot be NULL");
return NULL;
}
return callFetchAttestationCallback(callbackHandle, reportData, &outlen);
}
int check_sev_snp() {
int fd = open(SEV_GUEST_DRIVER_PATH, O_RDONLY);
if (fd == -1) {
perror("Error opening /dev/sev-guest");
fprintf(stderr, "SEV guest driver is not available.\n");
return -1;
} else {
close(fd);
}
return 1;
}
int compute_sha256_of_public_key_nonce(X509 *cert, u_char *nonce, u_char *hash) {
EVP_PKEY *pkey = NULL;
u_char *pubkey_buf = NULL;
u_char *concatinated = NULL;
int pubkey_len = 0;
int totla_len = 0;
pkey = X509_get_pubkey(cert);
if (pkey == NULL) {
fprintf(stderr, "Failed to extract public key from certificate\n");
return 0;
}
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
if (pubkey_len <= 0) {
fprintf(stderr, "Failed to convert public key to DER format\n");
EVP_PKEY_free(pkey);
return -1;
}
totla_len = pubkey_len + CLIENT_RANDOM_SIZE;
concatinated = (u_char*)malloc(totla_len);
if (concatinated == NULL) {
perror("failed to allocate memory");
return -1;
}
memcpy(concatinated, nonce, CLIENT_RANDOM_SIZE);
memcpy(concatinated + CLIENT_RANDOM_SIZE, pubkey_buf, pubkey_len);
// Compute the SHA-512 hash of the DER-encoded public key and the random nonce
SHA512(concatinated, totla_len, hash);
// Clean up
EVP_PKEY_free(pkey);
OPENSSL_free(pubkey_buf);
free(concatinated);
return 0; // Success
}
/*
Evidence request extension
- Contains a random nonce that goes into the attestation report
- Is sent in the ClientHello message
*/
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *out,
void *add_arg)
{
free((void *)out);
}
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
evidence_request *er = (evidence_request*)malloc(sizeof(evidence_request));
if (er == NULL) {
perror("could not allocate memory");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
if (ext_data != NULL) {
if (RAND_bytes(ext_data->er.data, CLIENT_RANDOM_SIZE) != 1) {
perror("could not generate random bytes, will use SSL client random");
SSL_get_client_random(s, ext_data->er.data, CLIENT_RANDOM_SIZE);
}
} else {
fprintf(stderr, "add_arg is NULL\n");
free(er);
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
memcpy(er->data, ext_data->er.data, CLIENT_RANDOM_SIZE);
er->tee_type = AMD_TEE;
ext_data->er.tee_type = AMD_TEE;
*out = (const u_char *)er;
*outlen = sizeof(evidence_request);
return 1;
}
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
if (ext_data != NULL) {
int32_t *platform_type = (int32_t*)malloc(sizeof(int32_t));
if (platform_type == NULL) {
perror("could not allocate memory");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
if (check_sev_snp() > 0) {
*platform_type = AMD_TEE;
} else {
*platform_type = NO_TEE;
}
if ((*platform_type != ext_data->er.tee_type) || (*platform_type == NO_TEE)) {
*platform_type = NO_TEE;
ext_data->er.tee_type = NO_TEE;
} else {
ext_data->er.tee_type = AMD_TEE;
ext_data->fetch_attestation_handler = fetchAttestationCallback(ext_data->er.tee_type);
}
*out = (u_char*)platform_type;
*outlen = sizeof(int32_t);
} else {
fprintf(stderr, "add_arg is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
return 1;
}
default:
break;
}
fprintf(stderr, "bad context\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
{
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
evidence_request *er = (evidence_request*)in;
if (ext_data != NULL) {
memcpy(ext_data->er.data, er->data, CLIENT_RANDOM_SIZE);
ext_data->er.tee_type = er->tee_type;
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
}
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
{
int *tee_type = (int*)in;
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
if (ext_data != NULL) {
ext_data->er.tee_type = *tee_type;
if (ext_data->er.tee_type != NO_TEE) {
ext_data->verification_validation_handler = validationVerificationCallback(ext_data->er.tee_type);
} else {
fprintf(stderr, "must use a TEE for aTLS\n");
return 0;
}
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
}
default:
fprintf(stderr, "bad context\n");
return 0;
}
}
/*
Attestation Certificate extension
- Contains the attestation report
- The attestation report contains the hash of the nonce and the Public Key of the x.509 Agent certificate
*/
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *out,
void *add_arg)
{
free((void *)out);
}
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
return 1;
case SSL_EXT_TLS1_3_CERTIFICATE:
{
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
if (ext_data != NULL) {
u_char *attestation_report;
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
if (hash == NULL) {
perror("could not allocate memory");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
if (x != NULL) {
int ret = compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash);
if (ret != 0) {
fprintf(stderr, "error while calculating hash\n");
free(hash);
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
} else {
fprintf(stderr, "agent certificate must be used for aTLS\n");
free(hash);
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
attestation_report = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, hash);
if (attestation_report == NULL) {
fprintf(stderr, "attestation report is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
free(hash);
*out = attestation_report;
*outlen = ATTESTATION_REPORT_SIZE;
return 1;
} else {
fprintf(stderr, "add_arg is NULL\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
}
default:
fprintf(stderr, "bad context\n");
*al = SSL_AD_INTERNAL_ERROR;
return -1;
}
}
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const u_char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg)
{
switch (context)
{
case SSL_EXT_CLIENT_HELLO:
// Return 1 so the server can return the custom certificate extension.
return 1;
case SSL_EXT_TLS1_3_CERTIFICATE:
{
if (x != NULL) {
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
if (ext_data != NULL) {
char *attestation_report = (char*)malloc(ATTESTATION_REPORT_SIZE*sizeof(char));
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
int res = 0;
if (hash == NULL || attestation_report == NULL) {
perror("could not allocate memory");
if (hash != NULL) free(hash);
if (attestation_report != NULL) free(attestation_report);
return 0;
}
if (compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash) != 0) {
fprintf(stderr, "calculating hash failed\n");
free(attestation_report);
free(hash);
return 0;
}
memcpy(attestation_report, in, inlen);
res = triggerVerificationValidationCallback(ext_data->verification_validation_handler,
attestation_report,
ATTESTATION_REPORT_SIZE,
hash);
free(attestation_report);
free(hash);
if (res != 0) {
fprintf(stderr, "verification and validation failed, aborting connection\n");
return 0;
}
} else {
fprintf(stderr, "parse_arg is NULL\n");
return 0;
}
return 1;
} else {
fprintf(stderr, "agent certificates must be used for aTLS\n");
return 0;
}
}
default:
fprintf(stderr, "bad context\n");
return 0;
}
}
+100
View File
@@ -0,0 +1,100 @@
#ifndef ATLS_EXTENSION_H
#define ATLS_EXTENSION_H
#include <openssl/ssl.h>
#include <arpa/inet.h>
#define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65
#define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66
#define ATTESTATION_REPORT_SIZE 0x4A0
#define REPORT_DATA_SIZE 64
#define CLIENT_RANDOM_SIZE 32
#define TLS_CLIENT_CTX 0
#define TLS_SERVER_CTX 1
#define SEV_GUEST_DRIVER_PATH "/dev/sev-guest"
#define NO_TEE 0
#define AMD_TEE 1
typedef struct evidence_request
{
int tee_type;
char data[CLIENT_RANDOM_SIZE];
} evidence_request;
typedef struct tls_extension_data
{
uintptr_t fetch_attestation_handler;
uintptr_t verification_validation_handler;
evidence_request er;
} tls_extension_data;
typedef struct tls_server_connection
{
int server_fd;
char* cert;
int cert_len;
char* key;
int key_len;
struct sockaddr_storage addr;
uintptr_t fetch_attestation_handler;
} tls_server_connection;
typedef struct tls_connection
{
SSL_CTX *ctx;
SSL *ssl;
int socket_fd;
struct sockaddr_storage local_addr;
struct sockaddr_storage remote_addr;
tls_extension_data tls_ext_data;
} tls_connection;
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port);
tls_connection* tls_server_accept(tls_server_connection *tls_server);
int tls_server_close(tls_server_connection *tls_server);
int tls_read(tls_connection *conn, void *buf, int num);
int tls_write(tls_connection *conn, const void *buf, int num);
int tls_close(tls_connection *conn);
tls_connection* new_tls_connection(char *address, int port);
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
char* tls_return_addr(struct sockaddr_storage *addr);
int tls_return_port(struct sockaddr_storage *addr);
int compute_sha256_of_public_key(X509 *cert, unsigned char *hash);
// Extensions
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *out,
void *add_arg);
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg);
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg);
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *out,
void *add_arg);
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char **out,
size_t *outlen, X509 *x,
size_t chainidx, int *al,
void *add_arg);
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
unsigned int context,
const unsigned char *in,
size_t inlen, X509 *x,
size_t chainidx, int *al,
void *parse_arg);
#endif // ATLS_EXTENSION_H
+586
View File
@@ -0,0 +1,586 @@
#include "extensions.h"
#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <stdio.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <ifaddrs.h>
void init_openssl() {
SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
}
void cleanup_openssl() {
EVP_cleanup();
}
int load_certificates_from_memory(SSL_CTX* ctx, const char* cert, int cert_len, const char* key, int key_len) {
BIO* cert_bio = NULL;
BIO* key_bio = NULL;
X509* x509_cert = NULL;
EVP_PKEY* pkey = NULL;
int success = 0;
cert_bio = BIO_new_mem_buf(cert, cert_len);
if (cert_bio == NULL) {
goto cleanup;
}
key_bio = BIO_new_mem_buf(key, key_len);
if (key_bio == NULL) {
goto cleanup;
}
x509_cert = PEM_read_bio_X509(cert_bio, NULL, 0, NULL);
if (x509_cert == NULL) {
goto cleanup;
}
pkey = PEM_read_bio_PrivateKey(key_bio, NULL, 0, NULL);
if (pkey == NULL) {
goto cleanup;
}
if (SSL_CTX_use_certificate(ctx, x509_cert) <= 0) {
goto cleanup;
}
if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) {
goto cleanup;
}
if (SSL_CTX_check_private_key(ctx) <= 0) {
goto cleanup;
}
success = 1;
cleanup:
if (cert_bio) BIO_free(cert_bio);
if (key_bio) BIO_free(key_bio);
if (x509_cert) X509_free(x509_cert);
if (pkey) EVP_PKEY_free(pkey);
if (!success) {
ERR_print_errors_fp(stderr);
}
return success;
}
int enforce_tls1_3_only(SSL_CTX *ctx) {
if (SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION) == 0) {
return 0;
}
if (SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) == 0) {
return 0;
}
return 1;
}
SSL_CTX *create_context(int is_server) {
const SSL_METHOD *method;
SSL_CTX *ctx;
if (is_server) {
method = TLS_server_method();
} else {
method = TLS_client_method();
}
ctx = SSL_CTX_new(method);
if (!ctx) {
perror("Unable to create SSL context");
ERR_print_errors_fp(stderr);
return NULL;
}
if (!enforce_tls1_3_only(ctx)) {
fprintf(stderr, "could not enforce TLS1.3\n");
SSL_CTX_free(ctx);
return NULL;
}
return ctx;
}
int add_custom_tls_extension(SSL_CTX *ctx, tls_connection *conn) {
uint32_t flags_nonce = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS;
uint32_t flags_attestation = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_CERTIFICATE;
int ret = 1;
void *data = NULL;
if (conn != NULL) {
data = (void*)&conn->tls_ext_data;
}
ret = SSL_CTX_add_custom_ext(ctx,
EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE,
flags_nonce,
evidence_request_ext_add_cb,
evidence_request_ext_free_cb,
data,
evidence_request_ext_parse_cb,
data);
if (ret != 1) {
return 0;
}
ret = SSL_CTX_add_custom_ext(ctx,
ATTESTATION_CERTIFICATE_EXTENSION_TYPE,
flags_attestation,
attestation_certificate_ext_add_cb,
attestation_certificate_ext_free_cb,
data,
attestation_certificate_ext_parse_cb,
data);
if (ret != 1) {
return 0;
}
return ret;
}
// Function to start the TLS server
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port) {
tls_server_connection *tls_server = (tls_server_connection*)malloc(sizeof(tls_server_connection));
int opt = 0;
if (tls_server == NULL) {
perror("memory could not be allocated");
return NULL;
}
init_openssl();
tls_server->cert = (char*)malloc(cert_len * sizeof(char));
if (tls_server->cert == NULL) {
perror("memory could not be allocated");
goto cleanup_tls_server;
}
tls_server->key = (char*)malloc(key_len * sizeof(char));
if (tls_server->key == NULL) {
perror("memory could not be allocated");
goto cleanup_cert;
}
memcpy(tls_server->cert, cert, cert_len);
memcpy(tls_server->key, key, key_len);
tls_server->cert_len = cert_len;
tls_server->key_len = key_len;
tls_server->server_fd = socket(AF_INET6, SOCK_STREAM, 0);
if (tls_server->server_fd < 0) {
perror("Unable to create socket");
goto cleanup_key;
}
// Enable both IPv4-mapped and IPv6 addresses
if (setsockopt(tls_server->server_fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) != 0) {
perror("setsockopt(IPV6_V6ONLY) failed");
goto cleanup_socket;
}
// Configure address structure
memset(&(tls_server->addr), 0, sizeof(tls_server->addr));
struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&tls_server->addr;
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
// Set the appropriate address (IPv4-mapped if needed)
if (ip == NULL || (strlen(ip) + 1) < INET_ADDRSTRLEN) {
addr6->sin6_addr = in6addr_any;
} else if (strchr(ip, ':') != NULL) {
if (inet_pton(AF_INET6, ip, &(addr6->sin6_addr)) <= 0) {
perror("Invalid IPv6 address");
goto cleanup_socket;
}
} else {
struct in_addr ipv4_addr;
if (inet_pton(AF_INET, ip, &ipv4_addr) <= 0) {
perror("Invalid IPv4 address");
goto cleanup_socket;
}
memset(&addr6->sin6_addr, 0, sizeof(addr6->sin6_addr));
addr6->sin6_addr.s6_addr[10] = 0xff;
addr6->sin6_addr.s6_addr[11] = 0xff;
memcpy(&addr6->sin6_addr.s6_addr[12], &ipv4_addr, sizeof(ipv4_addr));
}
if (bind(tls_server->server_fd, (struct sockaddr*)&(tls_server->addr), sizeof(tls_server->addr)) < 0) {
perror("Unable to bind");
goto cleanup_socket;
}
if (listen(tls_server->server_fd, SOMAXCONN) < 0) {
perror("Unable to listen");
goto cleanup_socket;
}
printf("Listening on port: %d\n", port);
return tls_server;
// Cleanup labels
cleanup_socket:
close(tls_server->server_fd);
cleanup_key:
free(tls_server->key);
cleanup_cert:
free(tls_server->cert);
cleanup_tls_server:
free(tls_server);
return NULL;
}
// Function to accept a client connection
tls_connection* tls_server_accept(tls_server_connection *tls_server) {
uint32_t len = sizeof(struct sockaddr_storage);
tls_connection *conn = (tls_connection*)malloc(sizeof(tls_connection));
int client_fd = -1;
int ret = 0;
if (conn == NULL) {
perror("Unable to allocate memory for tls_connection");
return NULL;
}
conn->ctx = NULL;
conn->ssl = NULL;
conn->socket_fd = -1;
// Initialize the context
conn->ctx = create_context(TLS_SERVER_CTX);
if (conn->ctx == NULL) {
perror("Unable to create context");
goto cleanup_conn;
}
// Load certificates
if (!load_certificates_from_memory(conn->ctx, tls_server->cert, tls_server->cert_len, tls_server->key, tls_server->key_len)) {
fprintf(stderr, "Failed to load certificates\n");
goto cleanup_ctx;
}
// Add custom TLS extension
ret = add_custom_tls_extension(conn->ctx, conn);
if (!ret) {
perror("Unable to add custom tls extensions");
goto cleanup_ctx;
}
// Accept client connection
client_fd = accept(tls_server->server_fd, NULL, NULL);
if (client_fd < 0) {
perror("Unable to accept connection");
goto cleanup_ctx;
}
// Create SSL object
conn->ssl = SSL_new(conn->ctx);
if (conn->ssl == NULL) {
perror("Unable to create SSL object");
goto cleanup_fd;
}
// Set file descriptor and assign handlers
conn->socket_fd = client_fd;
conn->tls_ext_data.fetch_attestation_handler = tls_server->fetch_attestation_handler;
SSL_set_fd(conn->ssl, client_fd);
// Get local address
if (getsockname(client_fd, (struct sockaddr *)&conn->local_addr, &len) == -1) {
perror("getsockname failed during TLS server accept");
goto cleanup_ssl;
}
// Get remote address
if (getpeername(client_fd, (struct sockaddr *)&conn->remote_addr, &len) == -1) {
perror("getpeername failed during TLS server accept");
goto cleanup_ssl;
}
// Perform SSL handshake
ret = SSL_accept(conn->ssl);
if (ret <= 0) {
perror("SSL handshake failed during accept");
goto cleanup_ssl;
}
return conn;
cleanup_ssl:
if (conn->ssl) SSL_free(conn->ssl);
cleanup_fd:
if (client_fd >= 0) close(client_fd);
cleanup_ctx:
if (conn->ctx) SSL_CTX_free(conn->ctx);
cleanup_conn:
free(conn);
return NULL;
}
// Function to close the server
int tls_server_close(tls_server_connection *tls_server) {
close(tls_server->server_fd);
cleanup_openssl();
free(tls_server->cert);
free(tls_server->key);
free(tls_server);
return 0;
}
int tls_read(tls_connection *conn, void *buf, int num) {
return SSL_read(conn->ssl, buf, num);
}
int tls_write(tls_connection *conn, const void *buf, int num) {
if (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN) {
return 0;
}
return SSL_write(conn->ssl, buf, num);
}
int tls_close(tls_connection *conn) {
if (conn != NULL) {
if (conn->ssl != NULL) {
int ret = 0;
while (ret == 0) {
ret = SSL_shutdown(conn->ssl);
if (ret < 0) {
fprintf(stderr, "SSL did not shutdown correctly\n");
free(conn);
close(conn->socket_fd);
conn = NULL;
return -1;
}
}
conn->ssl = NULL;
}
if (conn->socket_fd >= 0) {
close(conn->socket_fd);
conn->socket_fd = -1;
}
SSL_free(conn->ssl);
if (conn->ctx != NULL) {
SSL_CTX_free(conn->ctx);
conn->ctx = NULL;
}
free(conn);
conn = NULL;
return 1;
}
return 0;
}
char* tls_return_addr(struct sockaddr_storage *addr) {
socklen_t addr_len = sizeof(struct sockaddr_storage);
int inet_size =addr->ss_family == AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN;
char *ip_str = (char*)malloc(inet_size*sizeof(char));
void * addr_ptr;
if (ip_str == NULL) {
perror("memory could not be allocated");
return NULL;
}
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
addr_ptr = &(ipv4->sin_addr);
} else if (addr->ss_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
addr_ptr = &(ipv6->sin6_addr);
} else {
fprintf(stderr, "unknown family: %d\n", addr->ss_family);
free(ip_str);
return NULL;
}
if (inet_ntop(addr->ss_family, addr_ptr, ip_str, inet_size) == NULL) {
perror("inet_ntop failed");
free(ip_str);
return NULL;
}
return ip_str;
}
int tls_return_port(struct sockaddr_storage *addr) {
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
return ntohs(ipv4->sin_port);
} else if (addr->ss_family == AF_INET6) {
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
return ntohs(ipv6->sin6_port);
}
fprintf(stderr, "cannot return port from unknown family: %d\n", addr->ss_family);
return -1;
}
tls_connection* new_tls_connection(char *address, int port) {
SSL_CTX *ctx = NULL;
SSL *ssl = NULL;
int socket_fd = -1;
int status;
struct addrinfo hints, *res = NULL, *p = NULL;
char port_str[6];
tls_connection *conn = NULL;
socklen_t addr_len;
int ret = 0;
conn = (tls_connection*)malloc(sizeof(tls_connection));
if (!conn) {
perror("Failed to allocate memory for atls connection");
return NULL;
}
// Format the port string
snprintf(port_str, sizeof(port_str), "%d", port);
// Initialize OpenSSL
init_openssl();
// Create SSL context
ctx = create_context(TLS_CLIENT_CTX);
if (!ctx) {
perror("Could not create context");
goto cleanup_conn;
}
conn->ctx = ctx;
// Add custom TLS extension
ret = add_custom_tls_extension(conn->ctx, conn);
if (!ret) {
perror("Unable to add custom tls extensions");
goto cleanup_ctx;
}
// Prepare the address info hints
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = 0;
hints.ai_protocol = 0;
// Get address info
status = getaddrinfo(address, port_str, &hints, &res);
if (status != 0) {
perror("getaddrinfo error");
goto cleanup_ctx;
}
// Iterate through the results and try to connect
for (p = res; p != NULL; p = p->ai_next) {
socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (socket_fd < 0) {
perror("unable to create socket");
continue;
}
if (connect(socket_fd, p->ai_addr, p->ai_addrlen)) {
close(socket_fd);
continue;
}
memcpy(&conn->local_addr, p->ai_addr, p->ai_addrlen);
break;
}
freeaddrinfo(res);
if (p == NULL) {
goto cleanup_ctx;
}
conn->socket_fd = socket_fd;
// Retrieve and store the remote address
addr_len = sizeof(conn->remote_addr);
if (getpeername(socket_fd, (struct sockaddr *)&conn->remote_addr, &addr_len) == -1) {
perror("getpeername failed");
goto cleanup_socket;
}
// Create the SSL structure
ssl = SSL_new(ctx);
if (!ssl) {
perror("Failed to create SSL object");
goto cleanup_socket;
}
if (!SSL_set_fd(ssl, socket_fd)) {
perror("Failed to set SSL file descriptor");
goto cleanup_ssl;
}
conn->ssl = ssl;
// Perform the SSL handshake
if (SSL_connect(ssl) <= 0) {
fprintf(stderr, "SSL handshake failed\n");
goto cleanup_ssl;
}
return conn;
cleanup_ssl:
if (ssl) SSL_free(ssl);
cleanup_socket:
if (socket_fd >= 0) close(socket_fd);
cleanup_ctx:
if (ctx) SSL_CTX_free(ctx);
cleanup_conn:
free(conn);
return NULL;
}
int set_socket_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
struct timeval timeout;
timeout.tv_sec = timeout_sec;
timeout.tv_usec = timeout_usec;
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
return -1;
}
return 0;
}
+198
View File
@@ -0,0 +1,198 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !embed
// +build !embed
package quoteprovider
import (
"fmt"
"io"
"os"
"path"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/client"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"github.com/google/logger"
"google.golang.org/protobuf/proto"
)
const (
cocosDirectory = ".cocos"
caBundleName = "ask_ark.pem"
attestationReportSize = 0x4A0
reportDataSize = 64
sevProductNameMilan = "Milan"
sevProductNameGenoa = "Genoa"
)
var (
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
timeout = time.Minute * 2
maxTryDelay = time.Second * 30
)
var (
errProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa))
errReportSize = errors.New("attestation report size mismatch")
errAttVerification = errors.New("attestation verification failed")
errAttValidation = errors.New("attestation validation failed")
)
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
product := cfg.RootOfTrust.ProductLine
chain := attestation.GetCertificateChain()
if chain == nil {
chain = &sevsnp.CertificateChain{}
attestation.CertificateChain = chain
}
if len(chain.GetAskCert()) == 0 || len(chain.GetArkCert()) == 0 {
homePath, err := os.UserHomeDir()
if err != nil {
return err
}
bundlePath := path.Join(homePath, cocosDirectory, product, caBundleName)
if _, err := os.Stat(bundlePath); err == nil {
amdRootCerts := trust.AMDRootCerts{}
if err := amdRootCerts.FromKDSCert(bundlePath); err != nil {
return err
}
chain.ArkCert = amdRootCerts.ProductCerts.Ark.Raw
chain.AskCert = amdRootCerts.ProductCerts.Ask.Raw
}
}
return nil
}
func copyConfig(attConf *check.Config) (*check.Config, error) {
copy := proto.Clone(attConf).(*check.Config)
return copy, nil
}
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
if err != nil {
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(errAttVerification, err))
}
if cfg.Policy.Product == nil {
productName := sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN
switch cfg.RootOfTrust.ProductLine {
case sevProductNameMilan:
productName = sevsnp.SevProduct_SEV_PRODUCT_MILAN
case sevProductNameGenoa:
productName = sevsnp.SevProduct_SEV_PRODUCT_GENOA
default:
}
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
return errProductLine
}
sopts.Product = &sevsnp.SevProduct{
Name: productName,
}
} else {
sopts.Product = cfg.Policy.Product
}
sopts.Getter = &trust.RetryHTTPSGetter{
Timeout: timeout,
MaxRetryDelay: maxTryDelay,
Getter: &trust.SimpleHTTPSGetter{},
}
if err := fillInAttestationLocal(attestationPB, cfg); err != nil {
return fmt.Errorf("failed to fill the attestation with local ARK and ASK certificates %v", err)
}
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
return errors.Wrap(errAttVerification, err)
}
return nil
}
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
opts, err := validate.PolicyToOptions(cfg.Policy)
if err != nil {
return fmt.Errorf("failed to get policy for validation %v", errors.Wrap(errAttVerification, err))
}
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
return errors.Wrap(errAttValidation, err)
}
return nil
}
func GetQuoteProvider() (client.QuoteProvider, error) {
return client.GetQuoteProvider()
}
func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error {
config, err := copyConfig(&AttConfigurationSEVSNP)
if err != nil {
return fmt.Errorf("failed to create a copy of backend configuration")
}
config.Policy.ReportData = reportData[:]
return VerifyAndValidate(attestationBytes, config)
}
func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error {
logger.Init("", false, false, io.Discard)
if len(attestationReport) < attestationReportSize {
return errReportSize
}
attestationBytes := attestationReport[:attestationReportSize]
attestationPB, err := abi.ReportCertsToProto(attestationBytes)
if err != nil {
return fmt.Errorf("failed to convert attestation bytes to struct %v", errors.Wrap(errAttVerification, err))
}
if err = verifyReport(attestationPB, cfg); err != nil {
return err
}
if err = validateReport(attestationPB, cfg); err != nil {
return err
}
return nil
}
func FetchAttestation(reportDataSlice []byte) ([]byte, error) {
var reportData [reportDataSize]byte
qp, err := GetQuoteProvider()
if err != nil {
return []byte{}, fmt.Errorf("could not get quote provider")
}
if len(reportData) > reportDataSize {
return []byte{}, fmt.Errorf("attestation report size mismatch")
}
copy(reportData[:], reportDataSlice)
rawQuote, err := qp.GetRawQuote(reportData)
if err != nil {
return []byte{}, fmt.Errorf("failed to get raw quote")
}
return rawQuote, nil
}
+216
View File
@@ -0,0 +1,216 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !embed
// +build !embed
package quoteprovider
import (
"fmt"
"os"
"testing"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protojson"
)
const (
measurementOffset = 0x90
signatureOffset = 0x2A0
)
func TestFillInAttestationLocal(t *testing.T) {
tempDir, err := os.MkdirTemp("", "test_home")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
cocosDir := tempDir + "/.cocos/Milan"
err = os.MkdirAll(cocosDir, 0o755)
require.NoError(t, err)
bundleContent := []byte("mock ASK ARK bundle")
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
require.NoError(t, err)
config := check.Config{
RootOfTrust: &check.RootOfTrust{},
Policy: &check.Policy{},
}
tests := []struct {
name string
attestation *sevsnp.Attestation
err error
}{
{
name: "Empty attestation",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
err: nil,
},
{
name: "Attestation with existing chain",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("existing ASK cert"),
ArkCert: []byte("existing ARK cert"),
},
},
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := fillInAttestationLocal(tt.attestation, &config)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func TestVerifyAttestationReportSuccess(t *testing.T) {
file, reportData := prepareForTestVerifyAttestationReport(t)
tests := []struct {
name string
attestationReport []byte
reportData []byte
goodProduct int
err error
}{
{
name: "Valid attestation, validation and verification is performed succsessfully",
attestationReport: file,
reportData: reportData,
goodProduct: 1,
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
file, reportData := prepareForTestVerifyAttestationReport(t)
// Change random data so in the signature so the signature failes
file[signatureOffset] = file[signatureOffset] ^ 0x01
tests := []struct {
name string
attestationReport []byte
reportData []byte
err error
}{
{
name: "Valid attestation, distorted signature",
attestationReport: file,
reportData: reportData,
err: errAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
file, reportData := prepareForTestVerifyAttestationReport(t)
tests := []struct {
name string
attestationReport []byte
reportData []byte
err error
}{
{
name: "Valid attestation, unknown product",
attestationReport: file,
reportData: reportData,
err: errProductLine,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AttConfigurationSEVSNP.RootOfTrust.ProductLine = ""
AttConfigurationSEVSNP.Policy.Product = nil
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
file, reportData := prepareForTestVerifyAttestationReport(t)
// Change random data in the measurement so the measurement does not match
file[measurementOffset] = file[measurementOffset] ^ 0x01
tests := []struct {
name string
attestationReport []byte
reportData []byte
err error
}{
{
name: "Valid attestation, malformed policy (measurement)",
attestationReport: file,
reportData: reportData,
err: errAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) {
file, err := os.ReadFile("../../../attestation.bin")
require.NoError(t, err)
rr, err := abi.ReportCertsToProto(file)
require.NoError(t, err)
if len(file) < attestationReportSize {
file = append(file, make([]byte, attestationReportSize-len(file))...)
}
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
backendinfoFile, err := os.ReadFile("../../../scripts/backend_info/backend_info.json")
require.NoError(t, err)
err = protojson.Unmarshal(backendinfoFile, &AttConfigurationSEVSNP)
require.NoError(t, err)
AttConfigurationSEVSNP.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
AttConfigurationSEVSNP.Policy.FamilyId = rr.Report.FamilyId
AttConfigurationSEVSNP.Policy.ImageId = rr.Report.ImageId
AttConfigurationSEVSNP.Policy.Measurement = rr.Report.Measurement
AttConfigurationSEVSNP.Policy.HostData = rr.Report.HostData
AttConfigurationSEVSNP.Policy.ReportIdMa = rr.Report.ReportIdMa
AttConfigurationSEVSNP.RootOfTrust.ProductLine = sevProductNameMilan
return file, rr.Report.ReportData
}
+8 -6
View File
@@ -20,13 +20,15 @@ func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.Ag
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
})
if client.Secure() != grpc.WithATLS {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, errors.Wrap(err, ErrAgentServiceUnavailable)
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, errors.Wrap(err, ErrAgentServiceUnavailable)
}
}
return client, agent.NewAgentServiceClient(client.Connection()), nil
+43 -122
View File
@@ -3,27 +3,24 @@
package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/json"
"fmt"
"net"
"os"
"path"
"strconv"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/google/go-sev-guest/verify"
"github.com/google/go-sev-guest/verify/trust"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/encoding/protojson"
)
type security int
@@ -32,14 +29,12 @@ const (
withoutTLS security = iota
withTLS
withmTLS
withaTLS
)
const (
cocosDirectory = ".cocos"
caBundleName = "ask_ark.pem"
productNameMilan = "Milan"
productNameGenoa = "Genoa"
attestationReportSize = 0x4A0
AttestationReportSize = 0x4A0
WithATLS = "with aTLS"
)
var (
@@ -49,20 +44,11 @@ var (
ErrBackendInfoMissing = errors.New("failed due to missing backend info file")
ErrBackendInfoDecode = errors.New("failed to decode backend info file")
errCertificateParse = errors.New("failed to parse x509 certificate")
errAttVerification = errors.New("attestation verification failed")
errAttValidation = errors.New("attestation validation failed")
errCustomExtension = errors.New("failed due to missing custom extension")
errAttVerification = errors.New("certificat is not sefl signed")
errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
errFailedToLoadRootCA = errors.New("failed to load root ca file")
)
var (
customSEVSNPExtensionOID = asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6}
attestationConfiguration = AttestationConfiguration{}
timeout = time.Minute * 2
maxTryDelay = time.Second * 30
)
type Config struct {
ClientCert string `env:"CLIENT_CERT" envDefault:""`
ClientKey string `env:"CLIENT_KEY" envDefault:""`
@@ -73,11 +59,6 @@ type Config struct {
BackendInfo string `env:"BACKEND_INFO" envDefault:""`
}
type AttestationConfiguration struct {
SNPPolicy *check.Policy `json:"snp_policy,omitempty"`
RootOfTrust *check.RootOfTrust `json:"root_of_trust,omitempty"`
}
type Client interface {
// Close closes gRPC connection.
Close() error
@@ -124,6 +105,8 @@ func (c *client) Secure() string {
return "with TLS"
case withmTLS:
return "with mTLS"
case withaTLS:
return WithATLS
case withoutTLS:
fallthrough
default:
@@ -144,16 +127,18 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
tc := insecure.NewCredentials()
if cfg.AttestedTLS {
err := ReadBackendInfo(cfg.BackendInfo, &attestationConfiguration)
err := ReadBackendInfo(cfg.BackendInfo, &quoteprovider.AttConfigurationSEVSNP)
if err != nil {
return nil, secure, errors.Wrap(fmt.Errorf("failed to read Backend Info"), err)
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyAttestationReportTLS,
VerifyPeerCertificate: verifyPeerCertificateATLS,
}
tc = credentials.NewTLS(tlsConfig)
opts = append(opts, grpc.WithContextDialer(CustomDialer))
secure = withaTLS
} else {
if cfg.ServerCAFile != "" {
tlsConfig := &tls.Config{}
@@ -195,17 +180,14 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
return conn, secure, nil
}
func ReadBackendInfo(manifestPath string, attestationConfiguration *AttestationConfiguration) error {
func ReadBackendInfo(manifestPath string, attestationConfiguration *check.Config) error {
if manifestPath != "" {
manifest, err := os.Open(manifestPath)
manifest, err := os.ReadFile(manifestPath)
if err != nil {
return errors.Wrap(errBackendInfoOpen, err)
}
defer manifest.Close()
decoder := json.NewDecoder(manifest)
err = decoder.Decode(attestationConfiguration)
if err != nil {
if err := protojson.Unmarshal(manifest, attestationConfiguration); err != nil {
return errors.Wrap(ErrBackendInfoDecode, err)
}
@@ -215,69 +197,37 @@ func ReadBackendInfo(manifestPath string, attestationConfiguration *AttestationC
return ErrBackendInfoMissing
}
func verifyAttestationReportTLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("could not create a custom dialer")
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, fmt.Errorf("bad format of IP address: %v", err)
}
conn, err := atls.DialTLSClient(ip, p)
if err != nil {
return nil, fmt.Errorf("could not create TLS connection")
}
return conn, nil
}
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errors.Wrap(errCertificateParse, err)
}
for _, ext := range cert.Extensions {
if ext.Id.Equal(customSEVSNPExtensionOID) {
// Check if the certificate is self-signed
err := checkIfCertificateSelfSigned(cert)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
expectedReportData := sha3.Sum512(publicKeyBytes)
attestationConfiguration.SNPPolicy.ReportData = expectedReportData[:]
// Attestation verification and validation
sopts, err := verify.RootOfTrustToOptions(attestationConfiguration.RootOfTrust)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
sopts.Product = attestationConfiguration.SNPPolicy.Product
sopts.Getter = &trust.RetryHTTPSGetter{
Timeout: timeout,
MaxRetryDelay: maxTryDelay,
Getter: &trust.SimpleHTTPSGetter{},
}
attestation_bytes := ext.Value[:attestationReportSize]
attestationPB, err := abi.ReportCertsToProto(attestation_bytes)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
if err := fillInAttestationLocal(attestationPB); err != nil {
return err
}
if err = verify.SnpAttestation(attestationPB, sopts); err != nil {
return errors.Wrap(errAttVerification, err)
}
opts, err := validate.PolicyToOptions(attestationConfiguration.SNPPolicy)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
return errors.Wrap(errAttValidation, err)
}
return nil
}
err = checkIfCertificateSelfSigned(cert)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
return errCustomExtension
return nil
}
func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
@@ -295,32 +245,3 @@ func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
return nil
}
func fillInAttestationLocal(attestation *sevsnp.Attestation) error {
product := attestationConfiguration.RootOfTrust.ProductLine
chain := attestation.GetCertificateChain()
if chain == nil {
chain = &sevsnp.CertificateChain{}
attestation.CertificateChain = chain
}
if len(chain.GetAskCert()) == 0 || len(chain.GetArkCert()) == 0 {
homePath, err := os.UserHomeDir()
if err != nil {
return err
}
bundleFilePath := path.Join(homePath, cocosDirectory, product, caBundleName)
if _, err := os.Stat(bundleFilePath); err == nil {
amdRootCerts := trust.AMDRootCerts{}
if err := amdRootCerts.FromKDSCert(bundleFilePath); err != nil {
return err
}
chain.ArkCert = amdRootCerts.ProductCerts.Ark.Raw
chain.AskCert = amdRootCerts.ProductCerts.Ask.Raw
}
}
return nil
}
+4 -208
View File
@@ -7,7 +7,6 @@ import (
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
@@ -16,14 +15,9 @@ import (
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/tools/lib/report"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/proto"
)
func TestNewClient(t *testing.T) {
@@ -151,7 +145,7 @@ func TestClientSecure(t *testing.T) {
}
func TestReadBackendInfo(t *testing.T) {
validJSON := `{"snp_policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
validJSON := `{"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
invalidJSON := `{"invalid_json"`
cases := []struct {
@@ -194,12 +188,12 @@ func TestReadBackendInfo(t *testing.T) {
defer os.Remove(tt.manifestPath)
}
config := &AttestationConfiguration{}
err := ReadBackendInfo(tt.manifestPath, config)
config := check.Config{}
err := ReadBackendInfo(tt.manifestPath, &config)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
if tt.err == nil {
assert.NotNil(t, config.SNPPolicy)
assert.NotNil(t, config.Policy)
assert.NotNil(t, config.RootOfTrust)
}
})
@@ -292,132 +286,6 @@ func createTempFileHandle() (*os.File, error) {
return os.CreateTemp("", "test")
}
func TestVerifyAttestationReportTLS(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
file, err := os.ReadFile("../../../attestation.bin")
require.NoError(t, err)
rr, err := abi.ReportCertsToProto(file)
require.NoError(t, err)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
expectedReportData := sha3.Sum512(publicKeyBytes)
rr.Report.ReportData = expectedReportData[:]
file2, err := report.Transform(rr, "bin")
require.NoError(t, err)
file3, err := proto.Marshal(rr)
require.NoError(t, err)
if len(file) < attestationReportSize {
file = append(file, make([]byte, attestationReportSize-len(file))...)
}
if len(file2) < attestationReportSize {
file2 = append(file2, make([]byte, attestationReportSize-len(file2))...)
}
if len(file3) < attestationReportSize {
file3 = append(file3, make([]byte, attestationReportSize-len(file3))...)
}
template.ExtraExtensions = []pkix.Extension{
{
Id: customSEVSNPExtensionOID,
Value: file,
},
}
template2 := template
template2.ExtraExtensions[0].Value = file2
template3 := template
template3.ExtraExtensions[0].Value = file3
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certDERBadSig, err := x509.CreateCertificate(rand.Reader, &template2, &template2, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certDERMalformed, err := x509.CreateCertificate(rand.Reader, &template3, &template3, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
backendinfoFile, err := os.ReadFile("../../../scripts/backend_info/backend_info.json")
require.NoError(t, err)
attestationConfiguration = AttestationConfiguration{
SNPPolicy: &check.Policy{},
RootOfTrust: &check.RootOfTrust{},
}
err = json.Unmarshal(backendinfoFile, &attestationConfiguration)
require.NoError(t, err)
attestationConfiguration.SNPPolicy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
attestationConfiguration.SNPPolicy.FamilyId = rr.Report.FamilyId
attestationConfiguration.SNPPolicy.ImageId = rr.Report.ImageId
attestationConfiguration.SNPPolicy.Measurement = rr.Report.Measurement
attestationConfiguration.SNPPolicy.HostData = rr.Report.HostData
attestationConfiguration.SNPPolicy.ReportIdMa = rr.Report.ReportIdMa
attestationConfiguration.RootOfTrust.ProductLine = "Milan"
tests := []struct {
name string
rawCerts [][]byte
err error
}{
{
name: "Valid certificate with attestation, validation fails on report data",
rawCerts: [][]byte{certDER},
err: errAttVerification,
},
{
name: "Valid certificate with attestation, distorted signature",
rawCerts: [][]byte{certDERBadSig},
err: errAttVerification,
},
{
name: "Valid certificate with attestation, malformed policy",
rawCerts: [][]byte{certDERMalformed},
err: errAttVerification,
},
{
name: "Invalid certificate",
rawCerts: [][]byte{[]byte("invalid cert")},
err: errCertificateParse,
},
{
name: "Certificate without custom extension",
rawCerts: [][]byte{createCertWithoutCustomExtension(t)},
err: errCustomExtension,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := verifyAttestationReportTLS(tt.rawCerts, nil)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func TestCheckIfCertificateSelfSigned(t *testing.T) {
selfSignedCert := createSelfSignedCert(t)
@@ -446,78 +314,6 @@ func TestCheckIfCertificateSelfSigned(t *testing.T) {
}
}
func TestFillInAttestationLocal(t *testing.T) {
tempDir, err := os.MkdirTemp("", "test_home")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
cocosDir := tempDir + "/.cocos/Milan"
err = os.MkdirAll(cocosDir, 0o755)
require.NoError(t, err)
bundleContent := []byte("mock ASK ARK bundle")
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
require.NoError(t, err)
attestationConfiguration = AttestationConfiguration{
RootOfTrust: &check.RootOfTrust{},
SNPPolicy: &check.Policy{},
}
tests := []struct {
name string
attestation *sevsnp.Attestation
err error
}{
{
name: "Empty attestation",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
err: nil,
},
{
name: "Attestation with existing chain",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("existing ASK cert"),
ArkCert: []byte("existing ARK cert"),
},
},
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := fillInAttestationLocal(tt.attestation)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func createCertWithoutCustomExtension(t *testing.T) []byte {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
return certDER
}
func createSelfSignedCert(t *testing.T) *x509.Certificate {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
+1
View File
@@ -10,3 +10,4 @@ clap = { version = "4.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sev = "4.0.0"
base64 = "0.22.1"
+19 -92
View File
@@ -1,101 +1,28 @@
{
"snp_policy": {
"policy": 1966081,
"family_id": [
0
],
"image_id": [
0
],
"vmpl": {
"value": 0
},
"minimum_tcb": 15063977803600887811,
"minimum_launch_tcb": 15063977803600887811,
"policy": {
"policy": 196608,
"family_id": "AAAAAAAAAAAAAAAAAAAAAA==",
"image_id": "AAAAAAAAAAAAAAAAAAAAAA==",
"vmpl": 0,
"minimum_tcb": 15066229603414573059,
"minimum_launch_tcb": 15066229603414573059,
"require_author_key": false,
"measurement": [
0
],
"host_data": [
0
],
"report_id_ma": [
0
],
"chip_id": [
26,
177,
106,
181,
15,
165,
174,
66,
236,
140,
27,
37,
187,
218,
92,
11,
165,
234,
146,
187,
69,
89,
141,
64,
172,
132,
62,
35,
136,
46,
129,
2,
44,
188,
33,
180,
169,
233,
18,
188,
75,
68,
224,
255,
210,
45,
34,
122,
152,
115,
105,
58,
70,
52,
48,
121,
198,
166,
252,
245,
58,
69,
126,
147
],
"minimum_build": 7,
"measurement": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
"host_data": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
"report_id_ma": "//////////////////////////////////////////8=",
"chip_id": "GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==",
"minimum_build": 21,
"minimum_version": "1.55",
"permit_provisional_firmware": false,
"require_id_block": false
"require_id_block": false,
"product": {
"name": 1
}
},
"root_of_trust": {
"product": "Milan",
"check_crl": true,
"disallow_network": false
"disallow_network": false,
"product_line": "Milan"
}
}
}
+18 -22
View File
@@ -1,3 +1,4 @@
use base64::prelude::*;
use clap::{value_parser, Arg, Command};
use serde::Serialize;
use sev::firmware::host::*;
@@ -23,24 +24,19 @@ struct SevProduct {
name: i32,
}
#[derive(Serialize)]
struct Vmpl {
value: u32,
}
#[derive(Serialize)]
struct SnpPolicy {
policy: u64,
family_id: Vec<u8>,
image_id: Vec<u8>,
vmpl: Vmpl,
family_id: String,
image_id: String,
vmpl: u32,
minimum_tcb: u64,
minimum_launch_tcb: u64,
require_author_key: bool,
measurement: Vec<u8>,
host_data: Vec<u8>,
report_id_ma: Vec<u8>,
chip_id: Vec<u8>,
measurement: String,
host_data: String,
report_id_ma: String,
chip_id: String,
minimum_build: u32,
minimum_version: String,
permit_provisional_firmware: bool,
@@ -58,7 +54,7 @@ struct RootOfTrust {
#[derive(Serialize)]
struct Computation {
snp_policy: SnpPolicy,
policy: SnpPolicy,
root_of_trust: RootOfTrust,
}
@@ -125,24 +121,24 @@ fn main() {
let status: SnpPlatformStatus = firmware.snp_platform_status().unwrap();
let policy: u64 = *matches.get_one::<u64>("policy").unwrap();
let family_id = vec![0; 16];
let image_id = vec![0; 16];
let vmpl = Vmpl { value: 0 };
let family_id = BASE64_STANDARD.encode(vec![0; 16]);
let image_id = BASE64_STANDARD.encode(vec![0; 16]);
let vmpl = 0;
let minimum_tcb = get_uint64_from_tcb(&status.platform_tcb_version);
let minimum_launch_tcb = get_uint64_from_tcb(&status.platform_tcb_version);
let require_author_key = false;
let measurement = vec![0];
let host_data = vec![0];
let report_id_ma = vec![0xFF; 32];
let measurement = BASE64_STANDARD.encode(vec![0; 48]);
let host_data = BASE64_STANDARD.encode(vec![0; 32]);
let report_id_ma = BASE64_STANDARD.encode(vec![0xFF; 32]);
let cpu_id: Identifier = firmware.get_identifier().unwrap();
let chip_id: Vec<u8> = cpu_id.0;
let chip_id: String = BASE64_STANDARD.encode(cpu_id.0);
let minimum_build = status.build_id;
let minimum_version = status.version.to_string();
let permit_provisional_firmware = false;
let require_id_block = false;
let product = sev_product(get_sev_snp_processor());
let snp_policy = SnpPolicy {
let policy = SnpPolicy {
policy,
family_id,
image_id,
@@ -169,7 +165,7 @@ fn main() {
};
let computation = Computation {
snp_policy,
policy,
root_of_trust,
};