mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-192 - Add support for attested TLS (#279)
* add draft tls extension * add client support for ipv6 * remove vscode * add evidence request server payload * clean up the code * add fetch and verify for quote provider * add build parameters for buildroot * change Makefile to always enable CGO * fix ci * add malloc check for NULL * add copyright * renamed files and fix cgo lint * fix cache test * fix server tests * remove ineffective assignment * fix no-TLS connection * add check for SSL_set_fd failure * add tests for verification of attestation * fix CI * fix failing tests * fix backend tests * remove commented code * separate verify and validate function * fix failing test * Simplify function name --------- Co-authored-by: ultraviolet <cocosai@ultraviolet.local.pragmatic-it.com>
This commit is contained in:
committed by
GitHub
parent
6f747190b9
commit
e372cfc219
@@ -1,7 +1,7 @@
|
||||
BUILD_DIR = build
|
||||
SERVICES = manager agent cli
|
||||
BACKEND_INFO = backend_info
|
||||
CGO_ENABLED ?= 0
|
||||
CGO_ENABLED ?= 1
|
||||
GOARCH ?= amd64
|
||||
VERSION ?= $(shell git describe --abbrev=0 --tags --always)
|
||||
COMMIT ?= $(shell git rev-parse HEAD)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !embed
|
||||
// +build !embed
|
||||
|
||||
package quoteprovider
|
||||
|
||||
import "github.com/google/go-sev-guest/client"
|
||||
|
||||
func GetQuoteProvider() (client.QuoteProvider, error) {
|
||||
return client.GetQuoteProvider()
|
||||
}
|
||||
@@ -18,10 +18,10 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider"
|
||||
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
+3
-56
@@ -17,12 +17,10 @@ import (
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-sev-guest/tools/lib/report"
|
||||
"github.com/google/go-sev-guest/validate"
|
||||
"github.com/google/go-sev-guest/verify"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
@@ -269,7 +267,7 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
if err := verifyAndValidateAttestation(attestation); err != nil {
|
||||
if err := quoteprovider.VerifyAndValidate(attestation, &cfg); err != nil {
|
||||
printError(cmd, "Attestation validation and verification failed with error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
@@ -464,57 +462,6 @@ func (cli *CLI) NewValidateAttestationValidationCmd() *cobra.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func verifyAndValidateAttestation(attestation []byte) error {
|
||||
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.Policy.Product == nil {
|
||||
productName := sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN
|
||||
switch cfg.RootOfTrust.ProductLine {
|
||||
case sevProductNameMilan:
|
||||
productName = sevsnp.SevProduct_SEV_PRODUCT_MILAN
|
||||
case sevProductNameGenoa:
|
||||
productName = sevsnp.SevProduct_SEV_PRODUCT_GENOA
|
||||
default:
|
||||
}
|
||||
|
||||
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
|
||||
return fmt.Errorf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa)
|
||||
}
|
||||
|
||||
sopts.Product = &sevsnp.SevProduct{
|
||||
Name: productName,
|
||||
}
|
||||
} else {
|
||||
sopts.Product = cfg.Policy.Product
|
||||
}
|
||||
|
||||
sopts.Getter = &trust.RetryHTTPSGetter{
|
||||
Timeout: timeout,
|
||||
MaxRetryDelay: maxRetryDelay,
|
||||
Getter: &trust.SimpleHTTPSGetter{},
|
||||
}
|
||||
|
||||
// Only take the attestation report and ignore everything else.
|
||||
attestationPB, err := abi.ReportCertsToProto(attestation[:abi.ReportSize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = verify.SnpAttestation(attestationPB, sopts); err != nil {
|
||||
return err
|
||||
}
|
||||
opts, err := validate.PolicyToOptions(cfg.Policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseConfig decodes config passed as json for check.Config struct.
|
||||
// example
|
||||
/* {
|
||||
@@ -681,7 +628,7 @@ func getBase(val string) int {
|
||||
}
|
||||
|
||||
func validateInput() error {
|
||||
if len(cfg.RootOfTrust.CabundlePaths) != 0 || len(cfg.RootOfTrust.Cabundles) != 0 && cfg.RootOfTrust.Product == "" {
|
||||
if len(cfg.RootOfTrust.CabundlePaths) != 0 || len(cfg.RootOfTrust.Cabundles) != 0 && cfg.RootOfTrust.ProductLine == "" {
|
||||
return fmt.Errorf("product name must be set if CA bundles are provided")
|
||||
}
|
||||
|
||||
|
||||
+6
-11
@@ -4,7 +4,6 @@ package cli
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
type fieldType int
|
||||
@@ -39,11 +39,6 @@ var (
|
||||
errBackendField = errors.New("the specified field type does not exist in the backend information")
|
||||
)
|
||||
|
||||
type AttestationConfiguration struct {
|
||||
SNPPolicy *check.Policy `json:"snp_policy,omitempty"`
|
||||
RootOfTrust *check.RootOfTrust `json:"root_of_trust,omitempty"`
|
||||
}
|
||||
|
||||
func (cli *CLI) NewBackendCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "backend [command]",
|
||||
@@ -114,27 +109,27 @@ func changeAttestationConfiguration(fileName, base64Data string, expectedLength
|
||||
return errDataLength
|
||||
}
|
||||
|
||||
ac := AttestationConfiguration{}
|
||||
ac := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
|
||||
backendInfo, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return errors.Wrap(errReadingBackendInfoFile, err)
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(backendInfo, &ac); err != nil {
|
||||
if err = protojson.Unmarshal(backendInfo, &ac); err != nil {
|
||||
return errors.Wrap(errUnmarshalJSON, err)
|
||||
}
|
||||
|
||||
switch field {
|
||||
case measurementField:
|
||||
ac.SNPPolicy.Measurement = data
|
||||
ac.Policy.Measurement = data
|
||||
case hostDataField:
|
||||
ac.SNPPolicy.HostData = data
|
||||
ac.Policy.HostData = data
|
||||
default:
|
||||
return errBackendField
|
||||
}
|
||||
|
||||
fileJson, err := json.MarshalIndent(ac, "", " ")
|
||||
fileJson, err := protojson.Marshal(&ac)
|
||||
if err != nil {
|
||||
return errors.Wrap(errMarshalJSON, err)
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ package cli
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
func TestChangeAttestationConfiguration(t *testing.T) {
|
||||
@@ -18,14 +18,9 @@ func TestChangeAttestationConfiguration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
initialConfig := AttestationConfiguration{
|
||||
SNPPolicy: &check.Policy{
|
||||
Measurement: make([]byte, measurementLength),
|
||||
HostData: make([]byte, hostDataLength),
|
||||
},
|
||||
}
|
||||
initialConfig := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
|
||||
initialJSON, err := json.Marshal(initialConfig)
|
||||
initialJSON, err := protojson.Marshal(&initialConfig)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
@@ -91,15 +86,15 @@ func TestChangeAttestationConfiguration(t *testing.T) {
|
||||
content, err := os.ReadFile(tmpfile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
var config AttestationConfiguration
|
||||
err = json.Unmarshal(content, &config)
|
||||
config := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
err = protojson.Unmarshal(content, &config)
|
||||
require.NoError(t, err)
|
||||
|
||||
decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data)
|
||||
if tt.field == measurementField {
|
||||
assert.Equal(t, decodedData, config.SNPPolicy.Measurement)
|
||||
assert.Equal(t, decodedData, config.Policy.Measurement)
|
||||
} else if tt.field == hostDataField {
|
||||
assert.Equal(t, decodedData, config.SNPPolicy.HostData)
|
||||
assert.Equal(t, decodedData, config.Policy.HostData)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
+5
-4
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/kds"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
@@ -26,14 +27,14 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command {
|
||||
Example: "ca-bundle <path_to_platform_info_json>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
attestationConfiguration := grpc.AttestationConfiguration{}
|
||||
attestationConfiguration := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
err := grpc.ReadBackendInfo(args[0], &attestationConfiguration)
|
||||
if err != nil {
|
||||
printError(cmd, "Error while reading manifest: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
product := attestationConfiguration.RootOfTrust.Product
|
||||
product := attestationConfiguration.RootOfTrust.ProductLine
|
||||
|
||||
getter := trust.DefaultHTTPSGetter()
|
||||
caURL := kds.ProductCertChainURL(abi.VcekReportSigner, product)
|
||||
@@ -54,8 +55,8 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
bundleFilePath := path.Join(fileSavePath, product, caBundleName)
|
||||
if err = saveToFile(bundleFilePath, bundle); err != nil {
|
||||
bundlePath := path.Join(fileSavePath, product, caBundleName)
|
||||
if err = saveToFile(bundlePath, bundle); err != nil {
|
||||
printError(cmd, "Error while saving ARK-ASK to file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ func TestNewCABundleCmd(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
manifestContent := []byte(`{"root_of_trust": {"product": "Milan"}}`)
|
||||
manifestContent := []byte(`{"root_of_trust": {"product_line": "Milan"}}`)
|
||||
manifestPath := path.Join(tempDir, "manifest.json")
|
||||
err = os.WriteFile(manifestPath, manifestContent, 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
+1
-1
@@ -24,13 +24,13 @@ import (
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
managerevents "github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
qpmocks "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
qpmocks "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
|
||||
)
|
||||
|
||||
func TestSetDefaultValues(t *testing.T) {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -25,8 +24,8 @@ import (
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -45,6 +44,7 @@ const (
|
||||
notAfterYear = 1
|
||||
notAfterMonth = 0
|
||||
notAfterDay = 0
|
||||
nonceSize = 32
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -89,15 +89,12 @@ func (s *Server) Start() error {
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
creds := grpc.Creds(insecure.NewCredentials())
|
||||
var listener net.Listener = nil
|
||||
|
||||
switch {
|
||||
case s.Config.AttestedTLS:
|
||||
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.quoteProvider)
|
||||
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
@@ -113,7 +110,17 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
|
||||
listener, err = atls.Listen(
|
||||
s.Address,
|
||||
certificateBytes,
|
||||
privateKeyBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Listener for aTLS: %w", err)
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
|
||||
case s.Config.CertFile != "" || s.Config.KeyFile != "":
|
||||
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
|
||||
if err != nil {
|
||||
@@ -161,7 +168,18 @@ func (s *Server) Start() error {
|
||||
default:
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
|
||||
}
|
||||
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
default:
|
||||
var err error
|
||||
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
|
||||
}
|
||||
|
||||
@@ -237,24 +255,13 @@ func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
|
||||
func generateCertificatesForATLS(qp client.QuoteProvider) ([]byte, []byte, error) {
|
||||
func generateCertificatesForATLS() ([]byte, []byte, error) {
|
||||
curve := elliptic.P256()
|
||||
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private/public key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal the public key: %w", err)
|
||||
}
|
||||
|
||||
// The Attestation Report will be added as an X.509 certificate extension
|
||||
attestationReport, err := qp.GetRawQuote(sha3.Sum512(publicKeyBytes))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to fetch the attestation report: %w", err)
|
||||
}
|
||||
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(202403311),
|
||||
Subject: pkix.Name{
|
||||
@@ -270,13 +277,6 @@ func generateCertificatesForATLS(qp client.QuoteProvider) ([]byte, []byte, error
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6},
|
||||
Critical: false,
|
||||
Value: attestationReport,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
certDERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
|
||||
@@ -18,10 +18,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
authmocks "github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
@@ -139,7 +138,6 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
qp.On("GetRawQuote", mock.Anything).Return([]byte("mock-quote"), nil)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
@@ -158,7 +156,7 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(1000 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
|
||||
|
||||
@@ -9,22 +9,22 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/virtee/sev-snp-measure-go/cpuid"
|
||||
"github.com/virtee/sev-snp-measure-go/guest"
|
||||
"github.com/virtee/sev-snp-measure-go/vmmtypes"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const defGuestFeatures = 0x1
|
||||
|
||||
func (ms *managerService) FetchBackendInfo(_ context.Context, computationId string) ([]byte, error) {
|
||||
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "1966081")
|
||||
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "196608")
|
||||
|
||||
ms.mu.Lock()
|
||||
vm, exists := ms.vms[computationId]
|
||||
@@ -48,9 +48,9 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backendInfo grpc.AttestationConfiguration
|
||||
var backendInfo check.Config
|
||||
|
||||
if err = json.Unmarshal(f, &backendInfo); err != nil {
|
||||
if err = protojson.Unmarshal(f, &backendInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
|
||||
}
|
||||
}
|
||||
if measurement == nil {
|
||||
backendInfo.SNPPolicy.Measurement = measurement
|
||||
backendInfo.Policy.Measurement = measurement
|
||||
}
|
||||
|
||||
if config.HostData != "" {
|
||||
@@ -76,10 +76,10 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
backendInfo.SNPPolicy.HostData = hostData
|
||||
backendInfo.Policy.HostData = hostData
|
||||
}
|
||||
|
||||
f, err = json.Marshal(backendInfo)
|
||||
f, err = protojson.Marshal(&backendInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func createDummyBackendInfoBinary(t *testing.T, behavior string) string {
|
||||
switch behavior {
|
||||
case "success":
|
||||
content = []byte(`#!/bin/sh
|
||||
echo '{"snp_policy": {"measurement": null, "host_data": null}}' > backend_info.json
|
||||
echo '{"policy": {"measurement": null, "host_data": null}}' > backend_info.json
|
||||
`)
|
||||
case "fail":
|
||||
content = []byte(`#!/bin/sh
|
||||
|
||||
@@ -0,0 +1,419 @@
|
||||
// Code generated by cmd/cgo; DO NOT EDIT.
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
// #cgo LDFLAGS: -lssl -lcrypto
|
||||
// #include "extensions.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime/cgo"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
)
|
||||
|
||||
const (
|
||||
NoTee int = iota
|
||||
AmdSevSnp
|
||||
)
|
||||
|
||||
const (
|
||||
NO_ERROR = 0
|
||||
ERROR_ZERO_RETURN = 6
|
||||
ERROR_WANT_READ = 2
|
||||
ERROR_WANT_WRITE = 3
|
||||
ERROR_SYSCALL = 5
|
||||
ERROR_SSL = 1
|
||||
)
|
||||
|
||||
var (
|
||||
errListener = errors.New("listener could not be created")
|
||||
errBadIPFormat = errors.New("bad format of IP address")
|
||||
errCloseTLS = errors.New("could not close TLS connection")
|
||||
errConnFailed = errors.New("tls connection is nil")
|
||||
errWrite = errors.New("could not write to TLS")
|
||||
errTLSConn = errors.New("connection did not close correctly")
|
||||
errReadDeadline = errors.New("could not set read deadline, socket timeout failed")
|
||||
errWriteDeadline = errors.New("could not set write deadline, socket timeout failed")
|
||||
errConnCreate = errors.New("could not create connection")
|
||||
)
|
||||
|
||||
type ValidationVerification func(data1, data2 []byte) error
|
||||
type FetchAttestation func(data1 []byte) ([]byte, error)
|
||||
|
||||
func registerFetchAttestation(callback FetchAttestation) uintptr {
|
||||
handle := cgo.NewHandle(callback)
|
||||
return uintptr(handle)
|
||||
}
|
||||
|
||||
func registerValidationVerification(callback ValidationVerification) uintptr {
|
||||
handle := cgo.NewHandle(callback)
|
||||
return uintptr(handle)
|
||||
}
|
||||
|
||||
//export validationVerificationCallback
|
||||
func validationVerificationCallback(teeType C.int) uintptr {
|
||||
switch int(teeType) {
|
||||
case NoTee:
|
||||
return uintptr(0)
|
||||
case AmdSevSnp:
|
||||
return registerValidationVerification(quoteprovider.VerifyAttestationReportTLS)
|
||||
default:
|
||||
return uintptr(0)
|
||||
}
|
||||
}
|
||||
|
||||
//export fetchAttestationCallback
|
||||
func fetchAttestationCallback(teeType C.int) uintptr {
|
||||
switch int(teeType) {
|
||||
case NoTee:
|
||||
return uintptr(0)
|
||||
case AmdSevSnp:
|
||||
return registerFetchAttestation(quoteprovider.FetchAttestation)
|
||||
default:
|
||||
return uintptr(0)
|
||||
}
|
||||
}
|
||||
|
||||
//export callVerificationValidationCallback
|
||||
func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uchar, attReportSize C.int, repData *C.uchar) C.int {
|
||||
handle := cgo.Handle(callbackHandle)
|
||||
defer handle.Delete()
|
||||
|
||||
callback := handle.Value().(ValidationVerification)
|
||||
attestationReport := C.GoBytes(unsafe.Pointer(attReport), attReportSize)
|
||||
reportData := C.GoBytes(unsafe.Pointer(repData), agent.ReportDataSize)
|
||||
|
||||
err := callback(attestationReport, reportData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "callback failed %v", err)
|
||||
return C.int(-1)
|
||||
}
|
||||
|
||||
return C.int(0)
|
||||
}
|
||||
|
||||
//export callFetchAttestationCallback
|
||||
func callFetchAttestationCallback(callbackHandle uintptr, reportDataByte *C.uchar, outlen *C.int) *C.uchar {
|
||||
handle := cgo.Handle(callbackHandle)
|
||||
defer handle.Delete()
|
||||
|
||||
callback := handle.Value().(FetchAttestation)
|
||||
reportData := C.GoBytes(unsafe.Pointer(reportDataByte), agent.ReportDataSize)
|
||||
|
||||
quote, err := callback(reportData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "attestation callback returned nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
*outlen = C.int(len(quote))
|
||||
resultC := C.malloc(C.size_t(len(quote)))
|
||||
if resultC == nil {
|
||||
fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback")
|
||||
return nil
|
||||
}
|
||||
|
||||
C.memcpy(resultC, unsafe.Pointer("e[0]), C.size_t(len(quote)))
|
||||
|
||||
return (*C.uchar)(resultC)
|
||||
}
|
||||
|
||||
type ATLSServerListener struct {
|
||||
tlsListener *C.tls_server_connection
|
||||
}
|
||||
|
||||
func Listen(addr string, cert []byte, key []byte) (net.Listener, error) {
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errListener, err)
|
||||
}
|
||||
|
||||
p, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errBadIPFormat, err)
|
||||
}
|
||||
|
||||
cCertPEM := (*C.char)(unsafe.Pointer(&cert[0]))
|
||||
cKeyPEM := (*C.char)(unsafe.Pointer(&key[0]))
|
||||
cIP := C.CString(ip)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
atlsListener := C.start_tls_server(
|
||||
cCertPEM, C.int(len(cert)),
|
||||
cKeyPEM, C.int(len(key)),
|
||||
cIP, C.int(p))
|
||||
if atlsListener == nil {
|
||||
return nil, errors.Wrap(errListener, err)
|
||||
}
|
||||
|
||||
return &ATLSServerListener{tlsListener: atlsListener}, nil
|
||||
}
|
||||
|
||||
// accept implements the Accept method in the [Listener] interface; it
|
||||
// waits for the next call and returns a generic [Conn].
|
||||
func (l *ATLSServerListener) Accept() (net.Conn, error) {
|
||||
conn := C.tls_server_accept(l.tlsListener)
|
||||
if conn == nil {
|
||||
return &ATLSConn{tlsConn: nil}, nil
|
||||
}
|
||||
|
||||
return &ATLSConn{tlsConn: conn}, nil
|
||||
}
|
||||
|
||||
// close stops listening on the TCP address.
|
||||
// already Accepted connections are not closed.
|
||||
func (l *ATLSServerListener) Close() error {
|
||||
ret := C.tls_server_close(l.tlsListener)
|
||||
if ret != 0 {
|
||||
return errCloseTLS
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addr returns the listener's network address, a [*TCPAddr].
|
||||
// the Addr returned is shared by all invocations of Addr, so
|
||||
// do not modify it.
|
||||
func (l *ATLSServerListener) Addr() net.Addr {
|
||||
cIP := C.tls_return_addr(&l.tlsListener.addr)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoString(cIP)
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&l.tlsListener.addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
type ATLSConn struct {
|
||||
tlsConn *C.tls_connection
|
||||
fdReadMutex sync.Mutex
|
||||
fdWriteMutex sync.Mutex
|
||||
fdDelayMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Read(b []byte) (int, error) {
|
||||
c.fdReadMutex.Lock()
|
||||
defer c.fdReadMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return 0, errConnFailed
|
||||
}
|
||||
|
||||
n := int(C.tls_read(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
|
||||
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// call the C function SSL_get_error to interpret the error.
|
||||
errCode := int(C.SSL_get_error(c.tlsConn.ssl, C.int(n)))
|
||||
|
||||
// handle specific error codes returned by SSL_get_error.
|
||||
switch errCode {
|
||||
case NO_ERROR:
|
||||
return n, nil // no error.
|
||||
case ERROR_ZERO_RETURN:
|
||||
fmt.Fprintf(os.Stderr, "Connection closed by peer")
|
||||
return 0, io.EOF // connection closed.
|
||||
case ERROR_WANT_READ:
|
||||
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case ERROR_WANT_WRITE:
|
||||
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case ERROR_SYSCALL:
|
||||
fmt.Fprintf(os.Stderr, "I/O error")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
case ERROR_SSL:
|
||||
fmt.Fprintf(os.Stderr, "I/O error")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode)
|
||||
return 0, fmt.Errorf("SSL error")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Write(b []byte) (int, error) {
|
||||
c.fdWriteMutex.Lock()
|
||||
defer c.fdWriteMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return 0, errConnFailed
|
||||
}
|
||||
|
||||
n := int(C.tls_write(c.tlsConn, unsafe.Pointer(&b[0]), C.int(len(b))))
|
||||
if n < 0 {
|
||||
return 0, errWrite
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) Close() error {
|
||||
c.fdReadMutex.Lock()
|
||||
defer c.fdReadMutex.Unlock()
|
||||
|
||||
c.fdWriteMutex.Lock()
|
||||
defer c.fdWriteMutex.Unlock()
|
||||
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := C.tls_close(c.tlsConn)
|
||||
|
||||
if int(ret) < 0 {
|
||||
c.tlsConn = nil
|
||||
return errTLSConn
|
||||
} else if int(ret) == 1 {
|
||||
c.tlsConn = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) LocalAddr() net.Addr {
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
cIP := C.tls_return_addr(&c.tlsConn.local_addr)
|
||||
ipLength := C.strlen(cIP)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoStringN(cIP, C.int(ipLength))
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
fmt.Println("Invalid IP address")
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&c.tlsConn.local_addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) RemoteAddr() net.Addr {
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
cIP := C.tls_return_addr(&c.tlsConn.remote_addr)
|
||||
if cIP == nil {
|
||||
fmt.Println("RemoteAddr error while fetching ip")
|
||||
return nil
|
||||
}
|
||||
|
||||
ipLength := C.strlen(cIP)
|
||||
defer C.free(unsafe.Pointer(cIP))
|
||||
|
||||
ip := C.GoStringN(cIP, C.int(ipLength))
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
fmt.Println("Invalid IP address")
|
||||
return nil
|
||||
}
|
||||
|
||||
port := C.tls_return_port(&c.tlsConn.remote_addr)
|
||||
|
||||
return &net.TCPAddr{IP: parsedIP, Port: int(port)}
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errReadDeadline
|
||||
}
|
||||
|
||||
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errWriteDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetReadDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_read_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errReadDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ATLSConn) SetWriteDeadline(t time.Time) error {
|
||||
c.fdDelayMutex.Lock()
|
||||
defer c.fdDelayMutex.Unlock()
|
||||
|
||||
if c.tlsConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sec, usec := timeToTimeout(t)
|
||||
if C.set_socket_write_timeout(c.tlsConn, C.int(sec), C.int(usec)) < 0 {
|
||||
return errWriteDeadline
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DialTLSClient(hostname string, port int) (net.Conn, error) {
|
||||
cHostName := C.CString(hostname)
|
||||
defer C.free(unsafe.Pointer(cHostName))
|
||||
|
||||
conn := C.new_tls_connection(cHostName, C.int(port))
|
||||
if conn == nil {
|
||||
return nil, errConnCreate
|
||||
}
|
||||
|
||||
return &ATLSConn{tlsConn: conn}, nil
|
||||
}
|
||||
|
||||
func timeToTimeout(t time.Time) (int, int) {
|
||||
if t.IsZero() {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
d := time.Until(t)
|
||||
if d <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seconds := int(d.Seconds())
|
||||
microseconds := int(d.Nanoseconds()/1000) % 1_000_000
|
||||
return seconds, microseconds
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
#include "extensions.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <openssl/sha.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* attReport, int attReportSize, const u_char* repData);
|
||||
extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* reportDataByte, int* outlen);
|
||||
extern uintptr_t validationVerificationCallback(int teeType);
|
||||
extern uintptr_t fetchAttestationCallback(int teeType);
|
||||
|
||||
int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char *attestationReport, int reportSize, u_char *reportData) {
|
||||
if (attestationReport == NULL || reportData == NULL) {
|
||||
fprintf(stderr, "attestation data and report data cannot be NULL\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
return callVerificationValidationCallback(callbackHandle, attestationReport, reportSize, reportData);
|
||||
}
|
||||
|
||||
u_char* triggerFetchAttestationCallback(uintptr_t callbackHandle, char *reportData) {
|
||||
int outlen = REPORT_DATA_SIZE;
|
||||
|
||||
if(reportData == NULL) {
|
||||
fprintf(stderr, "Report data cannot be NULL");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return callFetchAttestationCallback(callbackHandle, reportData, &outlen);
|
||||
}
|
||||
|
||||
int check_sev_snp() {
|
||||
int fd = open(SEV_GUEST_DRIVER_PATH, O_RDONLY);
|
||||
|
||||
if (fd == -1) {
|
||||
perror("Error opening /dev/sev-guest");
|
||||
fprintf(stderr, "SEV guest driver is not available.\n");
|
||||
return -1;
|
||||
} else {
|
||||
close(fd);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int compute_sha256_of_public_key_nonce(X509 *cert, u_char *nonce, u_char *hash) {
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
u_char *concatinated = NULL;
|
||||
int pubkey_len = 0;
|
||||
int totla_len = 0;
|
||||
|
||||
pkey = X509_get_pubkey(cert);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
|
||||
totla_len = pubkey_len + CLIENT_RANDOM_SIZE;
|
||||
concatinated = (u_char*)malloc(totla_len);
|
||||
if (concatinated == NULL) {
|
||||
perror("failed to allocate memory");
|
||||
return -1;
|
||||
}
|
||||
memcpy(concatinated, nonce, CLIENT_RANDOM_SIZE);
|
||||
memcpy(concatinated + CLIENT_RANDOM_SIZE, pubkey_buf, pubkey_len);
|
||||
|
||||
// Compute the SHA-512 hash of the DER-encoded public key and the random nonce
|
||||
SHA512(concatinated, totla_len, hash);
|
||||
|
||||
// Clean up
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
free(concatinated);
|
||||
|
||||
return 0; // Success
|
||||
}
|
||||
|
||||
/*
|
||||
Evidence request extension
|
||||
- Contains a random nonce that goes into the attestation report
|
||||
- Is sent in the ClientHello message
|
||||
*/
|
||||
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *out,
|
||||
void *add_arg)
|
||||
{
|
||||
free((void *)out);
|
||||
}
|
||||
|
||||
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
evidence_request *er = (evidence_request*)malloc(sizeof(evidence_request));
|
||||
|
||||
if (er == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (ext_data != NULL) {
|
||||
if (RAND_bytes(ext_data->er.data, CLIENT_RANDOM_SIZE) != 1) {
|
||||
perror("could not generate random bytes, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.data, CLIENT_RANDOM_SIZE);
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
free(er);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(er->data, ext_data->er.data, CLIENT_RANDOM_SIZE);
|
||||
er->tee_type = AMD_TEE;
|
||||
ext_data->er.tee_type = AMD_TEE;
|
||||
|
||||
*out = (const u_char *)er;
|
||||
*outlen = sizeof(evidence_request);
|
||||
return 1;
|
||||
}
|
||||
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
int32_t *platform_type = (int32_t*)malloc(sizeof(int32_t));
|
||||
|
||||
if (platform_type == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (check_sev_snp() > 0) {
|
||||
*platform_type = AMD_TEE;
|
||||
} else {
|
||||
*platform_type = NO_TEE;
|
||||
}
|
||||
|
||||
if ((*platform_type != ext_data->er.tee_type) || (*platform_type == NO_TEE)) {
|
||||
*platform_type = NO_TEE;
|
||||
ext_data->er.tee_type = NO_TEE;
|
||||
} else {
|
||||
ext_data->er.tee_type = AMD_TEE;
|
||||
ext_data->fetch_attestation_handler = fetchAttestationCallback(ext_data->er.tee_type);
|
||||
}
|
||||
|
||||
*out = (u_char*)platform_type;
|
||||
*outlen = sizeof(int32_t);
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "bad context\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
evidence_request *er = (evidence_request*)in;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
memcpy(ext_data->er.data, er->data, CLIENT_RANDOM_SIZE);
|
||||
ext_data->er.tee_type = er->tee_type;
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
case SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS:
|
||||
{
|
||||
int *tee_type = (int*)in;
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
ext_data->er.tee_type = *tee_type;
|
||||
|
||||
if (ext_data->er.tee_type != NO_TEE) {
|
||||
ext_data->verification_validation_handler = validationVerificationCallback(ext_data->er.tee_type);
|
||||
} else {
|
||||
fprintf(stderr, "must use a TEE for aTLS\n");
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Attestation Certificate extension
|
||||
- Contains the attestation report
|
||||
- The attestation report contains the hash of the nonce and the Public Key of the x.509 Agent certificate
|
||||
*/
|
||||
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *out,
|
||||
void *add_arg)
|
||||
{
|
||||
free((void *)out);
|
||||
}
|
||||
|
||||
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
return 1;
|
||||
case SSL_EXT_TLS1_3_CERTIFICATE:
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
if (ext_data != NULL) {
|
||||
u_char *attestation_report;
|
||||
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
|
||||
|
||||
if (hash == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (x != NULL) {
|
||||
int ret = compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash);
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "error while calculating hash\n");
|
||||
free(hash);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "agent certificate must be used for aTLS\n");
|
||||
free(hash);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
attestation_report = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, hash);
|
||||
if (attestation_report == NULL) {
|
||||
fprintf(stderr, "attestation report is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
free(hash);
|
||||
|
||||
*out = attestation_report;
|
||||
*outlen = ATTESTATION_REPORT_SIZE;
|
||||
return 1;
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const u_char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg)
|
||||
{
|
||||
switch (context)
|
||||
{
|
||||
case SSL_EXT_CLIENT_HELLO:
|
||||
// Return 1 so the server can return the custom certificate extension.
|
||||
return 1;
|
||||
case SSL_EXT_TLS1_3_CERTIFICATE:
|
||||
{
|
||||
if (x != NULL) {
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
char *attestation_report = (char*)malloc(ATTESTATION_REPORT_SIZE*sizeof(char));
|
||||
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
|
||||
int res = 0;
|
||||
|
||||
if (hash == NULL || attestation_report == NULL) {
|
||||
perror("could not allocate memory");
|
||||
|
||||
if (hash != NULL) free(hash);
|
||||
if (attestation_report != NULL) free(attestation_report);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash) != 0) {
|
||||
fprintf(stderr, "calculating hash failed\n");
|
||||
free(attestation_report);
|
||||
free(hash);
|
||||
return 0;
|
||||
}
|
||||
|
||||
memcpy(attestation_report, in, inlen);
|
||||
|
||||
res = triggerVerificationValidationCallback(ext_data->verification_validation_handler,
|
||||
attestation_report,
|
||||
ATTESTATION_REPORT_SIZE,
|
||||
hash);
|
||||
free(attestation_report);
|
||||
free(hash);
|
||||
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "verification and validation failed, aborting connection\n");
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
} else {
|
||||
fprintf(stderr, "agent certificates must be used for aTLS\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "bad context\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
#ifndef ATLS_EXTENSION_H
|
||||
#define ATLS_EXTENSION_H
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
#include <arpa/inet.h>
|
||||
|
||||
#define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65
|
||||
#define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66
|
||||
#define ATTESTATION_REPORT_SIZE 0x4A0
|
||||
#define REPORT_DATA_SIZE 64
|
||||
#define CLIENT_RANDOM_SIZE 32
|
||||
#define TLS_CLIENT_CTX 0
|
||||
#define TLS_SERVER_CTX 1
|
||||
|
||||
#define SEV_GUEST_DRIVER_PATH "/dev/sev-guest"
|
||||
#define NO_TEE 0
|
||||
#define AMD_TEE 1
|
||||
|
||||
typedef struct evidence_request
|
||||
{
|
||||
int tee_type;
|
||||
char data[CLIENT_RANDOM_SIZE];
|
||||
} evidence_request;
|
||||
|
||||
typedef struct tls_extension_data
|
||||
{
|
||||
uintptr_t fetch_attestation_handler;
|
||||
uintptr_t verification_validation_handler;
|
||||
evidence_request er;
|
||||
} tls_extension_data;
|
||||
|
||||
typedef struct tls_server_connection
|
||||
{
|
||||
int server_fd;
|
||||
char* cert;
|
||||
int cert_len;
|
||||
char* key;
|
||||
int key_len;
|
||||
struct sockaddr_storage addr;
|
||||
uintptr_t fetch_attestation_handler;
|
||||
} tls_server_connection;
|
||||
|
||||
typedef struct tls_connection
|
||||
{
|
||||
SSL_CTX *ctx;
|
||||
SSL *ssl;
|
||||
int socket_fd;
|
||||
struct sockaddr_storage local_addr;
|
||||
struct sockaddr_storage remote_addr;
|
||||
tls_extension_data tls_ext_data;
|
||||
} tls_connection;
|
||||
|
||||
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port);
|
||||
tls_connection* tls_server_accept(tls_server_connection *tls_server);
|
||||
int tls_server_close(tls_server_connection *tls_server);
|
||||
int tls_read(tls_connection *conn, void *buf, int num);
|
||||
int tls_write(tls_connection *conn, const void *buf, int num);
|
||||
int tls_close(tls_connection *conn);
|
||||
tls_connection* new_tls_connection(char *address, int port);
|
||||
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
|
||||
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec);
|
||||
char* tls_return_addr(struct sockaddr_storage *addr);
|
||||
int tls_return_port(struct sockaddr_storage *addr);
|
||||
int compute_sha256_of_public_key(X509 *cert, unsigned char *hash);
|
||||
|
||||
// Extensions
|
||||
void evidence_request_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *out,
|
||||
void *add_arg);
|
||||
int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg);
|
||||
int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg);
|
||||
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *out,
|
||||
void *add_arg);
|
||||
int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char **out,
|
||||
size_t *outlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *add_arg);
|
||||
int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
const unsigned char *in,
|
||||
size_t inlen, X509 *x,
|
||||
size_t chainidx, int *al,
|
||||
void *parse_arg);
|
||||
|
||||
#endif // ATLS_EXTENSION_H
|
||||
@@ -0,0 +1,586 @@
|
||||
#include "extensions.h"
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/x509.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <stdio.h>
|
||||
#include <netdb.h>
|
||||
#include <unistd.h>
|
||||
#include <string.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/time.h>
|
||||
#include <ifaddrs.h>
|
||||
|
||||
void init_openssl() {
|
||||
SSL_load_error_strings();
|
||||
OpenSSL_add_ssl_algorithms();
|
||||
}
|
||||
|
||||
void cleanup_openssl() {
|
||||
EVP_cleanup();
|
||||
}
|
||||
|
||||
int load_certificates_from_memory(SSL_CTX* ctx, const char* cert, int cert_len, const char* key, int key_len) {
|
||||
BIO* cert_bio = NULL;
|
||||
BIO* key_bio = NULL;
|
||||
X509* x509_cert = NULL;
|
||||
EVP_PKEY* pkey = NULL;
|
||||
int success = 0;
|
||||
|
||||
cert_bio = BIO_new_mem_buf(cert, cert_len);
|
||||
if (cert_bio == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
key_bio = BIO_new_mem_buf(key, key_len);
|
||||
if (key_bio == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
x509_cert = PEM_read_bio_X509(cert_bio, NULL, 0, NULL);
|
||||
if (x509_cert == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
pkey = PEM_read_bio_PrivateKey(key_bio, NULL, 0, NULL);
|
||||
if (pkey == NULL) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_certificate(ctx, x509_cert) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_PrivateKey(ctx, pkey) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (SSL_CTX_check_private_key(ctx) <= 0) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
success = 1;
|
||||
|
||||
cleanup:
|
||||
if (cert_bio) BIO_free(cert_bio);
|
||||
if (key_bio) BIO_free(key_bio);
|
||||
if (x509_cert) X509_free(x509_cert);
|
||||
if (pkey) EVP_PKEY_free(pkey);
|
||||
|
||||
if (!success) {
|
||||
ERR_print_errors_fp(stderr);
|
||||
}
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
int enforce_tls1_3_only(SSL_CTX *ctx) {
|
||||
if (SSL_CTX_set_min_proto_version(ctx, TLS1_3_VERSION) == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
SSL_CTX *create_context(int is_server) {
|
||||
const SSL_METHOD *method;
|
||||
SSL_CTX *ctx;
|
||||
|
||||
if (is_server) {
|
||||
method = TLS_server_method();
|
||||
} else {
|
||||
method = TLS_client_method();
|
||||
}
|
||||
|
||||
ctx = SSL_CTX_new(method);
|
||||
if (!ctx) {
|
||||
perror("Unable to create SSL context");
|
||||
ERR_print_errors_fp(stderr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!enforce_tls1_3_only(ctx)) {
|
||||
fprintf(stderr, "could not enforce TLS1.3\n");
|
||||
SSL_CTX_free(ctx);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
int add_custom_tls_extension(SSL_CTX *ctx, tls_connection *conn) {
|
||||
uint32_t flags_nonce = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS;
|
||||
uint32_t flags_attestation = SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_CERTIFICATE;
|
||||
int ret = 1;
|
||||
void *data = NULL;
|
||||
|
||||
if (conn != NULL) {
|
||||
data = (void*)&conn->tls_ext_data;
|
||||
}
|
||||
|
||||
ret = SSL_CTX_add_custom_ext(ctx,
|
||||
EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE,
|
||||
flags_nonce,
|
||||
evidence_request_ext_add_cb,
|
||||
evidence_request_ext_free_cb,
|
||||
data,
|
||||
evidence_request_ext_parse_cb,
|
||||
data);
|
||||
if (ret != 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
ret = SSL_CTX_add_custom_ext(ctx,
|
||||
ATTESTATION_CERTIFICATE_EXTENSION_TYPE,
|
||||
flags_attestation,
|
||||
attestation_certificate_ext_add_cb,
|
||||
attestation_certificate_ext_free_cb,
|
||||
data,
|
||||
attestation_certificate_ext_parse_cb,
|
||||
data);
|
||||
if (ret != 1) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Function to start the TLS server
|
||||
tls_server_connection* start_tls_server(const char* cert, int cert_len, const char* key, int key_len, const char* ip, int port) {
|
||||
tls_server_connection *tls_server = (tls_server_connection*)malloc(sizeof(tls_server_connection));
|
||||
int opt = 0;
|
||||
|
||||
if (tls_server == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
init_openssl();
|
||||
|
||||
tls_server->cert = (char*)malloc(cert_len * sizeof(char));
|
||||
if (tls_server->cert == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
goto cleanup_tls_server;
|
||||
}
|
||||
|
||||
tls_server->key = (char*)malloc(key_len * sizeof(char));
|
||||
if (tls_server->key == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
goto cleanup_cert;
|
||||
}
|
||||
|
||||
memcpy(tls_server->cert, cert, cert_len);
|
||||
memcpy(tls_server->key, key, key_len);
|
||||
tls_server->cert_len = cert_len;
|
||||
tls_server->key_len = key_len;
|
||||
|
||||
tls_server->server_fd = socket(AF_INET6, SOCK_STREAM, 0);
|
||||
if (tls_server->server_fd < 0) {
|
||||
perror("Unable to create socket");
|
||||
goto cleanup_key;
|
||||
}
|
||||
|
||||
// Enable both IPv4-mapped and IPv6 addresses
|
||||
if (setsockopt(tls_server->server_fd, IPPROTO_IPV6, IPV6_V6ONLY, &opt, sizeof(opt)) != 0) {
|
||||
perror("setsockopt(IPV6_V6ONLY) failed");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
// Configure address structure
|
||||
memset(&(tls_server->addr), 0, sizeof(tls_server->addr));
|
||||
struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&tls_server->addr;
|
||||
addr6->sin6_family = AF_INET6;
|
||||
addr6->sin6_port = htons(port);
|
||||
|
||||
// Set the appropriate address (IPv4-mapped if needed)
|
||||
if (ip == NULL || (strlen(ip) + 1) < INET_ADDRSTRLEN) {
|
||||
addr6->sin6_addr = in6addr_any;
|
||||
} else if (strchr(ip, ':') != NULL) {
|
||||
if (inet_pton(AF_INET6, ip, &(addr6->sin6_addr)) <= 0) {
|
||||
perror("Invalid IPv6 address");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
} else {
|
||||
struct in_addr ipv4_addr;
|
||||
if (inet_pton(AF_INET, ip, &ipv4_addr) <= 0) {
|
||||
perror("Invalid IPv4 address");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
memset(&addr6->sin6_addr, 0, sizeof(addr6->sin6_addr));
|
||||
addr6->sin6_addr.s6_addr[10] = 0xff;
|
||||
addr6->sin6_addr.s6_addr[11] = 0xff;
|
||||
memcpy(&addr6->sin6_addr.s6_addr[12], &ipv4_addr, sizeof(ipv4_addr));
|
||||
}
|
||||
|
||||
if (bind(tls_server->server_fd, (struct sockaddr*)&(tls_server->addr), sizeof(tls_server->addr)) < 0) {
|
||||
perror("Unable to bind");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
if (listen(tls_server->server_fd, SOMAXCONN) < 0) {
|
||||
perror("Unable to listen");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
printf("Listening on port: %d\n", port);
|
||||
return tls_server;
|
||||
|
||||
// Cleanup labels
|
||||
cleanup_socket:
|
||||
close(tls_server->server_fd);
|
||||
cleanup_key:
|
||||
free(tls_server->key);
|
||||
cleanup_cert:
|
||||
free(tls_server->cert);
|
||||
cleanup_tls_server:
|
||||
free(tls_server);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Function to accept a client connection
|
||||
tls_connection* tls_server_accept(tls_server_connection *tls_server) {
|
||||
uint32_t len = sizeof(struct sockaddr_storage);
|
||||
tls_connection *conn = (tls_connection*)malloc(sizeof(tls_connection));
|
||||
int client_fd = -1;
|
||||
int ret = 0;
|
||||
|
||||
if (conn == NULL) {
|
||||
perror("Unable to allocate memory for tls_connection");
|
||||
return NULL;
|
||||
}
|
||||
conn->ctx = NULL;
|
||||
conn->ssl = NULL;
|
||||
conn->socket_fd = -1;
|
||||
|
||||
// Initialize the context
|
||||
conn->ctx = create_context(TLS_SERVER_CTX);
|
||||
if (conn->ctx == NULL) {
|
||||
perror("Unable to create context");
|
||||
goto cleanup_conn;
|
||||
}
|
||||
|
||||
// Load certificates
|
||||
if (!load_certificates_from_memory(conn->ctx, tls_server->cert, tls_server->cert_len, tls_server->key, tls_server->key_len)) {
|
||||
fprintf(stderr, "Failed to load certificates\n");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Add custom TLS extension
|
||||
ret = add_custom_tls_extension(conn->ctx, conn);
|
||||
if (!ret) {
|
||||
perror("Unable to add custom tls extensions");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Accept client connection
|
||||
client_fd = accept(tls_server->server_fd, NULL, NULL);
|
||||
if (client_fd < 0) {
|
||||
perror("Unable to accept connection");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Create SSL object
|
||||
conn->ssl = SSL_new(conn->ctx);
|
||||
if (conn->ssl == NULL) {
|
||||
perror("Unable to create SSL object");
|
||||
goto cleanup_fd;
|
||||
}
|
||||
|
||||
// Set file descriptor and assign handlers
|
||||
conn->socket_fd = client_fd;
|
||||
conn->tls_ext_data.fetch_attestation_handler = tls_server->fetch_attestation_handler;
|
||||
SSL_set_fd(conn->ssl, client_fd);
|
||||
|
||||
// Get local address
|
||||
if (getsockname(client_fd, (struct sockaddr *)&conn->local_addr, &len) == -1) {
|
||||
perror("getsockname failed during TLS server accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
// Get remote address
|
||||
if (getpeername(client_fd, (struct sockaddr *)&conn->remote_addr, &len) == -1) {
|
||||
perror("getpeername failed during TLS server accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
// Perform SSL handshake
|
||||
ret = SSL_accept(conn->ssl);
|
||||
if (ret <= 0) {
|
||||
perror("SSL handshake failed during accept");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
return conn;
|
||||
|
||||
cleanup_ssl:
|
||||
if (conn->ssl) SSL_free(conn->ssl);
|
||||
cleanup_fd:
|
||||
if (client_fd >= 0) close(client_fd);
|
||||
cleanup_ctx:
|
||||
if (conn->ctx) SSL_CTX_free(conn->ctx);
|
||||
cleanup_conn:
|
||||
free(conn);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Function to close the server
|
||||
int tls_server_close(tls_server_connection *tls_server) {
|
||||
close(tls_server->server_fd);
|
||||
cleanup_openssl();
|
||||
free(tls_server->cert);
|
||||
free(tls_server->key);
|
||||
free(tls_server);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tls_read(tls_connection *conn, void *buf, int num) {
|
||||
return SSL_read(conn->ssl, buf, num);
|
||||
}
|
||||
|
||||
int tls_write(tls_connection *conn, const void *buf, int num) {
|
||||
if (SSL_get_shutdown(conn->ssl) & SSL_SENT_SHUTDOWN) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return SSL_write(conn->ssl, buf, num);
|
||||
}
|
||||
|
||||
int tls_close(tls_connection *conn) {
|
||||
if (conn != NULL) {
|
||||
if (conn->ssl != NULL) {
|
||||
int ret = 0;
|
||||
|
||||
while (ret == 0) {
|
||||
ret = SSL_shutdown(conn->ssl);
|
||||
|
||||
if (ret < 0) {
|
||||
fprintf(stderr, "SSL did not shutdown correctly\n");
|
||||
free(conn);
|
||||
close(conn->socket_fd);
|
||||
conn = NULL;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
conn->ssl = NULL;
|
||||
}
|
||||
if (conn->socket_fd >= 0) {
|
||||
close(conn->socket_fd);
|
||||
conn->socket_fd = -1;
|
||||
}
|
||||
SSL_free(conn->ssl);
|
||||
if (conn->ctx != NULL) {
|
||||
SSL_CTX_free(conn->ctx);
|
||||
conn->ctx = NULL;
|
||||
}
|
||||
|
||||
free(conn);
|
||||
conn = NULL;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
char* tls_return_addr(struct sockaddr_storage *addr) {
|
||||
socklen_t addr_len = sizeof(struct sockaddr_storage);
|
||||
int inet_size =addr->ss_family == AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN;
|
||||
char *ip_str = (char*)malloc(inet_size*sizeof(char));
|
||||
void * addr_ptr;
|
||||
|
||||
if (ip_str == NULL) {
|
||||
perror("memory could not be allocated");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (addr->ss_family == AF_INET) {
|
||||
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
|
||||
addr_ptr = &(ipv4->sin_addr);
|
||||
} else if (addr->ss_family == AF_INET6) {
|
||||
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
|
||||
addr_ptr = &(ipv6->sin6_addr);
|
||||
} else {
|
||||
fprintf(stderr, "unknown family: %d\n", addr->ss_family);
|
||||
free(ip_str);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (inet_ntop(addr->ss_family, addr_ptr, ip_str, inet_size) == NULL) {
|
||||
perror("inet_ntop failed");
|
||||
free(ip_str);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ip_str;
|
||||
}
|
||||
|
||||
int tls_return_port(struct sockaddr_storage *addr) {
|
||||
if (addr->ss_family == AF_INET) {
|
||||
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
|
||||
return ntohs(ipv4->sin_port);
|
||||
} else if (addr->ss_family == AF_INET6) {
|
||||
struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
|
||||
return ntohs(ipv6->sin6_port);
|
||||
}
|
||||
|
||||
fprintf(stderr, "cannot return port from unknown family: %d\n", addr->ss_family);
|
||||
return -1;
|
||||
}
|
||||
|
||||
tls_connection* new_tls_connection(char *address, int port) {
|
||||
SSL_CTX *ctx = NULL;
|
||||
SSL *ssl = NULL;
|
||||
int socket_fd = -1;
|
||||
int status;
|
||||
struct addrinfo hints, *res = NULL, *p = NULL;
|
||||
char port_str[6];
|
||||
tls_connection *conn = NULL;
|
||||
socklen_t addr_len;
|
||||
int ret = 0;
|
||||
|
||||
conn = (tls_connection*)malloc(sizeof(tls_connection));
|
||||
if (!conn) {
|
||||
perror("Failed to allocate memory for atls connection");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Format the port string
|
||||
snprintf(port_str, sizeof(port_str), "%d", port);
|
||||
|
||||
// Initialize OpenSSL
|
||||
init_openssl();
|
||||
|
||||
// Create SSL context
|
||||
ctx = create_context(TLS_CLIENT_CTX);
|
||||
if (!ctx) {
|
||||
perror("Could not create context");
|
||||
goto cleanup_conn;
|
||||
}
|
||||
|
||||
conn->ctx = ctx;
|
||||
// Add custom TLS extension
|
||||
ret = add_custom_tls_extension(conn->ctx, conn);
|
||||
if (!ret) {
|
||||
perror("Unable to add custom tls extensions");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Prepare the address info hints
|
||||
memset(&hints, 0, sizeof(hints));
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
hints.ai_flags = 0;
|
||||
hints.ai_protocol = 0;
|
||||
|
||||
// Get address info
|
||||
status = getaddrinfo(address, port_str, &hints, &res);
|
||||
if (status != 0) {
|
||||
perror("getaddrinfo error");
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
// Iterate through the results and try to connect
|
||||
for (p = res; p != NULL; p = p->ai_next) {
|
||||
socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
|
||||
if (socket_fd < 0) {
|
||||
perror("unable to create socket");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (connect(socket_fd, p->ai_addr, p->ai_addrlen)) {
|
||||
close(socket_fd);
|
||||
continue;
|
||||
}
|
||||
|
||||
memcpy(&conn->local_addr, p->ai_addr, p->ai_addrlen);
|
||||
break;
|
||||
}
|
||||
|
||||
freeaddrinfo(res);
|
||||
|
||||
if (p == NULL) {
|
||||
goto cleanup_ctx;
|
||||
}
|
||||
|
||||
conn->socket_fd = socket_fd;
|
||||
|
||||
// Retrieve and store the remote address
|
||||
addr_len = sizeof(conn->remote_addr);
|
||||
if (getpeername(socket_fd, (struct sockaddr *)&conn->remote_addr, &addr_len) == -1) {
|
||||
perror("getpeername failed");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
// Create the SSL structure
|
||||
ssl = SSL_new(ctx);
|
||||
if (!ssl) {
|
||||
perror("Failed to create SSL object");
|
||||
goto cleanup_socket;
|
||||
}
|
||||
|
||||
if (!SSL_set_fd(ssl, socket_fd)) {
|
||||
perror("Failed to set SSL file descriptor");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
conn->ssl = ssl;
|
||||
|
||||
// Perform the SSL handshake
|
||||
if (SSL_connect(ssl) <= 0) {
|
||||
fprintf(stderr, "SSL handshake failed\n");
|
||||
goto cleanup_ssl;
|
||||
}
|
||||
|
||||
return conn;
|
||||
|
||||
cleanup_ssl:
|
||||
if (ssl) SSL_free(ssl);
|
||||
cleanup_socket:
|
||||
if (socket_fd >= 0) close(socket_fd);
|
||||
cleanup_ctx:
|
||||
if (ctx) SSL_CTX_free(ctx);
|
||||
cleanup_conn:
|
||||
free(conn);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int set_socket_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int set_socket_read_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int set_socket_write_timeout(tls_connection* conn, int timeout_sec, int timeout_usec) {
|
||||
struct timeval timeout;
|
||||
timeout.tv_sec = timeout_sec;
|
||||
timeout.tv_usec = timeout_usec;
|
||||
|
||||
if (setsockopt(conn->socket_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !embed
|
||||
// +build !embed
|
||||
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-sev-guest/validate"
|
||||
"github.com/google/go-sev-guest/verify"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/google/logger"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
cocosDirectory = ".cocos"
|
||||
caBundleName = "ask_ark.pem"
|
||||
attestationReportSize = 0x4A0
|
||||
reportDataSize = 64
|
||||
sevProductNameMilan = "Milan"
|
||||
sevProductNameGenoa = "Genoa"
|
||||
)
|
||||
|
||||
var (
|
||||
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
timeout = time.Minute * 2
|
||||
maxTryDelay = time.Second * 30
|
||||
)
|
||||
|
||||
var (
|
||||
errProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa))
|
||||
errReportSize = errors.New("attestation report size mismatch")
|
||||
errAttVerification = errors.New("attestation verification failed")
|
||||
errAttValidation = errors.New("attestation validation failed")
|
||||
)
|
||||
|
||||
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
|
||||
product := cfg.RootOfTrust.ProductLine
|
||||
|
||||
chain := attestation.GetCertificateChain()
|
||||
if chain == nil {
|
||||
chain = &sevsnp.CertificateChain{}
|
||||
attestation.CertificateChain = chain
|
||||
}
|
||||
if len(chain.GetAskCert()) == 0 || len(chain.GetArkCert()) == 0 {
|
||||
homePath, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bundlePath := path.Join(homePath, cocosDirectory, product, caBundleName)
|
||||
if _, err := os.Stat(bundlePath); err == nil {
|
||||
amdRootCerts := trust.AMDRootCerts{}
|
||||
if err := amdRootCerts.FromKDSCert(bundlePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chain.ArkCert = amdRootCerts.ProductCerts.Ark.Raw
|
||||
chain.AskCert = amdRootCerts.ProductCerts.Ask.Raw
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyConfig(attConf *check.Config) (*check.Config, error) {
|
||||
copy := proto.Clone(attConf).(*check.Config)
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(errAttVerification, err))
|
||||
}
|
||||
|
||||
if cfg.Policy.Product == nil {
|
||||
productName := sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN
|
||||
switch cfg.RootOfTrust.ProductLine {
|
||||
case sevProductNameMilan:
|
||||
productName = sevsnp.SevProduct_SEV_PRODUCT_MILAN
|
||||
case sevProductNameGenoa:
|
||||
productName = sevsnp.SevProduct_SEV_PRODUCT_GENOA
|
||||
default:
|
||||
}
|
||||
|
||||
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
|
||||
return errProductLine
|
||||
}
|
||||
|
||||
sopts.Product = &sevsnp.SevProduct{
|
||||
Name: productName,
|
||||
}
|
||||
} else {
|
||||
sopts.Product = cfg.Policy.Product
|
||||
}
|
||||
|
||||
sopts.Getter = &trust.RetryHTTPSGetter{
|
||||
Timeout: timeout,
|
||||
MaxRetryDelay: maxTryDelay,
|
||||
Getter: &trust.SimpleHTTPSGetter{},
|
||||
}
|
||||
|
||||
if err := fillInAttestationLocal(attestationPB, cfg); err != nil {
|
||||
return fmt.Errorf("failed to fill the attestation with local ARK and ASK certificates %v", err)
|
||||
}
|
||||
|
||||
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
opts, err := validate.PolicyToOptions(cfg.Policy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get policy for validation %v", errors.Wrap(errAttVerification, err))
|
||||
}
|
||||
|
||||
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
|
||||
return errors.Wrap(errAttValidation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetQuoteProvider() (client.QuoteProvider, error) {
|
||||
return client.GetQuoteProvider()
|
||||
}
|
||||
|
||||
func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error {
|
||||
config, err := copyConfig(&AttConfigurationSEVSNP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create a copy of backend configuration")
|
||||
}
|
||||
|
||||
config.Policy.ReportData = reportData[:]
|
||||
return VerifyAndValidate(attestationBytes, config)
|
||||
}
|
||||
|
||||
func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error {
|
||||
logger.Init("", false, false, io.Discard)
|
||||
|
||||
if len(attestationReport) < attestationReportSize {
|
||||
return errReportSize
|
||||
}
|
||||
attestationBytes := attestationReport[:attestationReportSize]
|
||||
|
||||
attestationPB, err := abi.ReportCertsToProto(attestationBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert attestation bytes to struct %v", errors.Wrap(errAttVerification, err))
|
||||
}
|
||||
|
||||
if err = verifyReport(attestationPB, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = validateReport(attestationPB, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAttestation(reportDataSlice []byte) ([]byte, error) {
|
||||
var reportData [reportDataSize]byte
|
||||
|
||||
qp, err := GetQuoteProvider()
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > reportDataSize {
|
||||
return []byte{}, fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
copy(reportData[:], reportDataSlice)
|
||||
|
||||
rawQuote, err := qp.GetRawQuote(reportData)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("failed to get raw quote")
|
||||
}
|
||||
|
||||
return rawQuote, nil
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !embed
|
||||
// +build !embed
|
||||
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const (
|
||||
measurementOffset = 0x90
|
||||
signatureOffset = 0x2A0
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
cocosDir := tempDir + "/.cocos/Milan"
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
Policy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := fillInAttestationLocal(tt.attestation, &config)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportSuccess(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
reportData []byte
|
||||
goodProduct int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, validation and verification is performed succsessfully",
|
||||
attestationReport: file,
|
||||
reportData: reportData,
|
||||
goodProduct: 1,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
|
||||
// Change random data so in the signature so the signature failes
|
||||
file[signatureOffset] = file[signatureOffset] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, distorted signature",
|
||||
attestationReport: file,
|
||||
reportData: reportData,
|
||||
err: errAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, unknown product",
|
||||
attestationReport: file,
|
||||
reportData: reportData,
|
||||
err: errProductLine,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
AttConfigurationSEVSNP.RootOfTrust.ProductLine = ""
|
||||
AttConfigurationSEVSNP.Policy.Product = nil
|
||||
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
|
||||
// Change random data in the measurement so the measurement does not match
|
||||
file[measurementOffset] = file[measurementOffset] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, malformed policy (measurement)",
|
||||
attestationReport: file,
|
||||
reportData: reportData,
|
||||
err: errAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) {
|
||||
file, err := os.ReadFile("../../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
rr, err := abi.ReportCertsToProto(file)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(file) < attestationReportSize {
|
||||
file = append(file, make([]byte, attestationReportSize-len(file))...)
|
||||
}
|
||||
|
||||
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
|
||||
backendinfoFile, err := os.ReadFile("../../../scripts/backend_info/backend_info.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = protojson.Unmarshal(backendinfoFile, &AttConfigurationSEVSNP)
|
||||
require.NoError(t, err)
|
||||
|
||||
AttConfigurationSEVSNP.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
|
||||
AttConfigurationSEVSNP.Policy.FamilyId = rr.Report.FamilyId
|
||||
AttConfigurationSEVSNP.Policy.ImageId = rr.Report.ImageId
|
||||
AttConfigurationSEVSNP.Policy.Measurement = rr.Report.Measurement
|
||||
AttConfigurationSEVSNP.Policy.HostData = rr.Report.HostData
|
||||
AttConfigurationSEVSNP.Policy.ReportIdMa = rr.Report.ReportIdMa
|
||||
AttConfigurationSEVSNP.RootOfTrust.ProductLine = sevProductNameMilan
|
||||
|
||||
return file, rr.Report.ReportData
|
||||
}
|
||||
@@ -20,13 +20,15 @@ func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.Ag
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
})
|
||||
if client.Secure() != grpc.WithATLS {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
})
|
||||
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, errors.Wrap(err, ErrAgentServiceUnavailable)
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, errors.Wrap(err, ErrAgentServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
return client, agent.NewAgentServiceClient(client.Connection()), nil
|
||||
|
||||
+43
-122
@@ -3,27 +3,24 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-sev-guest/validate"
|
||||
"github.com/google/go-sev-guest/verify"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
type security int
|
||||
@@ -32,14 +29,12 @@ const (
|
||||
withoutTLS security = iota
|
||||
withTLS
|
||||
withmTLS
|
||||
withaTLS
|
||||
)
|
||||
|
||||
const (
|
||||
cocosDirectory = ".cocos"
|
||||
caBundleName = "ask_ark.pem"
|
||||
productNameMilan = "Milan"
|
||||
productNameGenoa = "Genoa"
|
||||
attestationReportSize = 0x4A0
|
||||
AttestationReportSize = 0x4A0
|
||||
WithATLS = "with aTLS"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,20 +44,11 @@ var (
|
||||
ErrBackendInfoMissing = errors.New("failed due to missing backend info file")
|
||||
ErrBackendInfoDecode = errors.New("failed to decode backend info file")
|
||||
errCertificateParse = errors.New("failed to parse x509 certificate")
|
||||
errAttVerification = errors.New("attestation verification failed")
|
||||
errAttValidation = errors.New("attestation validation failed")
|
||||
errCustomExtension = errors.New("failed due to missing custom extension")
|
||||
errAttVerification = errors.New("certificat is not sefl signed")
|
||||
errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
|
||||
errFailedToLoadRootCA = errors.New("failed to load root ca file")
|
||||
)
|
||||
|
||||
var (
|
||||
customSEVSNPExtensionOID = asn1.ObjectIdentifier{1, 2, 3, 4, 5, 6}
|
||||
attestationConfiguration = AttestationConfiguration{}
|
||||
timeout = time.Minute * 2
|
||||
maxTryDelay = time.Second * 30
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClientCert string `env:"CLIENT_CERT" envDefault:""`
|
||||
ClientKey string `env:"CLIENT_KEY" envDefault:""`
|
||||
@@ -73,11 +59,6 @@ type Config struct {
|
||||
BackendInfo string `env:"BACKEND_INFO" envDefault:""`
|
||||
}
|
||||
|
||||
type AttestationConfiguration struct {
|
||||
SNPPolicy *check.Policy `json:"snp_policy,omitempty"`
|
||||
RootOfTrust *check.RootOfTrust `json:"root_of_trust,omitempty"`
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
// Close closes gRPC connection.
|
||||
Close() error
|
||||
@@ -124,6 +105,8 @@ func (c *client) Secure() string {
|
||||
return "with TLS"
|
||||
case withmTLS:
|
||||
return "with mTLS"
|
||||
case withaTLS:
|
||||
return WithATLS
|
||||
case withoutTLS:
|
||||
fallthrough
|
||||
default:
|
||||
@@ -144,16 +127,18 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
|
||||
tc := insecure.NewCredentials()
|
||||
|
||||
if cfg.AttestedTLS {
|
||||
err := ReadBackendInfo(cfg.BackendInfo, &attestationConfiguration)
|
||||
err := ReadBackendInfo(cfg.BackendInfo, "eprovider.AttConfigurationSEVSNP)
|
||||
if err != nil {
|
||||
return nil, secure, errors.Wrap(fmt.Errorf("failed to read Backend Info"), err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
VerifyPeerCertificate: verifyAttestationReportTLS,
|
||||
VerifyPeerCertificate: verifyPeerCertificateATLS,
|
||||
}
|
||||
tc = credentials.NewTLS(tlsConfig)
|
||||
opts = append(opts, grpc.WithContextDialer(CustomDialer))
|
||||
secure = withaTLS
|
||||
} else {
|
||||
if cfg.ServerCAFile != "" {
|
||||
tlsConfig := &tls.Config{}
|
||||
@@ -195,17 +180,14 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
|
||||
return conn, secure, nil
|
||||
}
|
||||
|
||||
func ReadBackendInfo(manifestPath string, attestationConfiguration *AttestationConfiguration) error {
|
||||
func ReadBackendInfo(manifestPath string, attestationConfiguration *check.Config) error {
|
||||
if manifestPath != "" {
|
||||
manifest, err := os.Open(manifestPath)
|
||||
manifest, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(errBackendInfoOpen, err)
|
||||
}
|
||||
defer manifest.Close()
|
||||
|
||||
decoder := json.NewDecoder(manifest)
|
||||
err = decoder.Decode(attestationConfiguration)
|
||||
if err != nil {
|
||||
if err := protojson.Unmarshal(manifest, attestationConfiguration); err != nil {
|
||||
return errors.Wrap(ErrBackendInfoDecode, err)
|
||||
}
|
||||
|
||||
@@ -215,69 +197,37 @@ func ReadBackendInfo(manifestPath string, attestationConfiguration *AttestationC
|
||||
return ErrBackendInfoMissing
|
||||
}
|
||||
|
||||
func verifyAttestationReportTLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create a custom dialer")
|
||||
}
|
||||
|
||||
p, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad format of IP address: %v", err)
|
||||
}
|
||||
|
||||
conn, err := atls.DialTLSClient(ip, p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create TLS connection")
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return errors.Wrap(errCertificateParse, err)
|
||||
}
|
||||
|
||||
for _, ext := range cert.Extensions {
|
||||
if ext.Id.Equal(customSEVSNPExtensionOID) {
|
||||
// Check if the certificate is self-signed
|
||||
err := checkIfCertificateSelfSigned(cert)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
expectedReportData := sha3.Sum512(publicKeyBytes)
|
||||
attestationConfiguration.SNPPolicy.ReportData = expectedReportData[:]
|
||||
|
||||
// Attestation verification and validation
|
||||
sopts, err := verify.RootOfTrustToOptions(attestationConfiguration.RootOfTrust)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
sopts.Product = attestationConfiguration.SNPPolicy.Product
|
||||
sopts.Getter = &trust.RetryHTTPSGetter{
|
||||
Timeout: timeout,
|
||||
MaxRetryDelay: maxTryDelay,
|
||||
Getter: &trust.SimpleHTTPSGetter{},
|
||||
}
|
||||
|
||||
attestation_bytes := ext.Value[:attestationReportSize]
|
||||
attestationPB, err := abi.ReportCertsToProto(attestation_bytes)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
if err := fillInAttestationLocal(attestationPB); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = verify.SnpAttestation(attestationPB, sopts); err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
opts, err := validate.PolicyToOptions(attestationConfiguration.SNPPolicy)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
|
||||
return errors.Wrap(errAttValidation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
err = checkIfCertificateSelfSigned(cert)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
return errCustomExtension
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
|
||||
@@ -295,32 +245,3 @@ func checkIfCertificateSelfSigned(cert *x509.Certificate) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func fillInAttestationLocal(attestation *sevsnp.Attestation) error {
|
||||
product := attestationConfiguration.RootOfTrust.ProductLine
|
||||
|
||||
chain := attestation.GetCertificateChain()
|
||||
if chain == nil {
|
||||
chain = &sevsnp.CertificateChain{}
|
||||
attestation.CertificateChain = chain
|
||||
}
|
||||
if len(chain.GetAskCert()) == 0 || len(chain.GetArkCert()) == 0 {
|
||||
homePath, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bundleFilePath := path.Join(homePath, cocosDirectory, product, caBundleName)
|
||||
if _, err := os.Stat(bundleFilePath); err == nil {
|
||||
amdRootCerts := trust.AMDRootCerts{}
|
||||
if err := amdRootCerts.FromKDSCert(bundleFilePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chain.ArkCert = amdRootCerts.ProductCerts.Ark.Raw
|
||||
chain.AskCert = amdRootCerts.ProductCerts.Ask.Raw
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -16,14 +15,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-sev-guest/tools/lib/report"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
@@ -151,7 +145,7 @@ func TestClientSecure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReadBackendInfo(t *testing.T) {
|
||||
validJSON := `{"snp_policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
|
||||
validJSON := `{"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
|
||||
invalidJSON := `{"invalid_json"`
|
||||
|
||||
cases := []struct {
|
||||
@@ -194,12 +188,12 @@ func TestReadBackendInfo(t *testing.T) {
|
||||
defer os.Remove(tt.manifestPath)
|
||||
}
|
||||
|
||||
config := &AttestationConfiguration{}
|
||||
err := ReadBackendInfo(tt.manifestPath, config)
|
||||
config := check.Config{}
|
||||
err := ReadBackendInfo(tt.manifestPath, &config)
|
||||
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
if tt.err == nil {
|
||||
assert.NotNil(t, config.SNPPolicy)
|
||||
assert.NotNil(t, config.Policy)
|
||||
assert.NotNil(t, config.RootOfTrust)
|
||||
}
|
||||
})
|
||||
@@ -292,132 +286,6 @@ func createTempFileHandle() (*os.File, error) {
|
||||
return os.CreateTemp("", "test")
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportTLS(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
file, err := os.ReadFile("../../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
rr, err := abi.ReportCertsToProto(file)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedReportData := sha3.Sum512(publicKeyBytes)
|
||||
rr.Report.ReportData = expectedReportData[:]
|
||||
|
||||
file2, err := report.Transform(rr, "bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
file3, err := proto.Marshal(rr)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(file) < attestationReportSize {
|
||||
file = append(file, make([]byte, attestationReportSize-len(file))...)
|
||||
}
|
||||
|
||||
if len(file2) < attestationReportSize {
|
||||
file2 = append(file2, make([]byte, attestationReportSize-len(file2))...)
|
||||
}
|
||||
|
||||
if len(file3) < attestationReportSize {
|
||||
file3 = append(file3, make([]byte, attestationReportSize-len(file3))...)
|
||||
}
|
||||
|
||||
template.ExtraExtensions = []pkix.Extension{
|
||||
{
|
||||
Id: customSEVSNPExtensionOID,
|
||||
Value: file,
|
||||
},
|
||||
}
|
||||
|
||||
template2 := template
|
||||
template2.ExtraExtensions[0].Value = file2
|
||||
|
||||
template3 := template
|
||||
template3.ExtraExtensions[0].Value = file3
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
certDERBadSig, err := x509.CreateCertificate(rand.Reader, &template2, &template2, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
certDERMalformed, err := x509.CreateCertificate(rand.Reader, &template3, &template3, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
backendinfoFile, err := os.ReadFile("../../../scripts/backend_info/backend_info.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationConfiguration = AttestationConfiguration{
|
||||
SNPPolicy: &check.Policy{},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
err = json.Unmarshal(backendinfoFile, &attestationConfiguration)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationConfiguration.SNPPolicy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
|
||||
attestationConfiguration.SNPPolicy.FamilyId = rr.Report.FamilyId
|
||||
attestationConfiguration.SNPPolicy.ImageId = rr.Report.ImageId
|
||||
attestationConfiguration.SNPPolicy.Measurement = rr.Report.Measurement
|
||||
attestationConfiguration.SNPPolicy.HostData = rr.Report.HostData
|
||||
attestationConfiguration.SNPPolicy.ReportIdMa = rr.Report.ReportIdMa
|
||||
attestationConfiguration.RootOfTrust.ProductLine = "Milan"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rawCerts [][]byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid certificate with attestation, validation fails on report data",
|
||||
rawCerts: [][]byte{certDER},
|
||||
err: errAttVerification,
|
||||
},
|
||||
{
|
||||
name: "Valid certificate with attestation, distorted signature",
|
||||
rawCerts: [][]byte{certDERBadSig},
|
||||
err: errAttVerification,
|
||||
},
|
||||
{
|
||||
name: "Valid certificate with attestation, malformed policy",
|
||||
rawCerts: [][]byte{certDERMalformed},
|
||||
err: errAttVerification,
|
||||
},
|
||||
{
|
||||
name: "Invalid certificate",
|
||||
rawCerts: [][]byte{[]byte("invalid cert")},
|
||||
err: errCertificateParse,
|
||||
},
|
||||
{
|
||||
name: "Certificate without custom extension",
|
||||
rawCerts: [][]byte{createCertWithoutCustomExtension(t)},
|
||||
err: errCustomExtension,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := verifyAttestationReportTLS(tt.rawCerts, nil)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIfCertificateSelfSigned(t *testing.T) {
|
||||
selfSignedCert := createSelfSignedCert(t)
|
||||
|
||||
@@ -446,78 +314,6 @@ func TestCheckIfCertificateSelfSigned(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
cocosDir := tempDir + "/.cocos/Milan"
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationConfiguration = AttestationConfiguration{
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
SNPPolicy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := fillInAttestationLocal(tt.attestation)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createCertWithoutCustomExtension(t *testing.T) []byte {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return certDER
|
||||
}
|
||||
|
||||
func createSelfSignedCert(t *testing.T) *x509.Certificate {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10,3 +10,4 @@ clap = { version = "4.0", features = ["derive"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
sev = "4.0.0"
|
||||
base64 = "0.22.1"
|
||||
|
||||
@@ -1,101 +1,28 @@
|
||||
{
|
||||
"snp_policy": {
|
||||
"policy": 1966081,
|
||||
"family_id": [
|
||||
0
|
||||
],
|
||||
"image_id": [
|
||||
0
|
||||
],
|
||||
"vmpl": {
|
||||
"value": 0
|
||||
},
|
||||
"minimum_tcb": 15063977803600887811,
|
||||
"minimum_launch_tcb": 15063977803600887811,
|
||||
"policy": {
|
||||
"policy": 196608,
|
||||
"family_id": "AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"image_id": "AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"vmpl": 0,
|
||||
"minimum_tcb": 15066229603414573059,
|
||||
"minimum_launch_tcb": 15066229603414573059,
|
||||
"require_author_key": false,
|
||||
"measurement": [
|
||||
0
|
||||
],
|
||||
"host_data": [
|
||||
0
|
||||
],
|
||||
"report_id_ma": [
|
||||
0
|
||||
],
|
||||
"chip_id": [
|
||||
26,
|
||||
177,
|
||||
106,
|
||||
181,
|
||||
15,
|
||||
165,
|
||||
174,
|
||||
66,
|
||||
236,
|
||||
140,
|
||||
27,
|
||||
37,
|
||||
187,
|
||||
218,
|
||||
92,
|
||||
11,
|
||||
165,
|
||||
234,
|
||||
146,
|
||||
187,
|
||||
69,
|
||||
89,
|
||||
141,
|
||||
64,
|
||||
172,
|
||||
132,
|
||||
62,
|
||||
35,
|
||||
136,
|
||||
46,
|
||||
129,
|
||||
2,
|
||||
44,
|
||||
188,
|
||||
33,
|
||||
180,
|
||||
169,
|
||||
233,
|
||||
18,
|
||||
188,
|
||||
75,
|
||||
68,
|
||||
224,
|
||||
255,
|
||||
210,
|
||||
45,
|
||||
34,
|
||||
122,
|
||||
152,
|
||||
115,
|
||||
105,
|
||||
58,
|
||||
70,
|
||||
52,
|
||||
48,
|
||||
121,
|
||||
198,
|
||||
166,
|
||||
252,
|
||||
245,
|
||||
58,
|
||||
69,
|
||||
126,
|
||||
147
|
||||
],
|
||||
"minimum_build": 7,
|
||||
"measurement": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
|
||||
"host_data": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
"report_id_ma": "//////////////////////////////////////////8=",
|
||||
"chip_id": "GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==",
|
||||
"minimum_build": 21,
|
||||
"minimum_version": "1.55",
|
||||
"permit_provisional_firmware": false,
|
||||
"require_id_block": false
|
||||
"require_id_block": false,
|
||||
"product": {
|
||||
"name": 1
|
||||
}
|
||||
},
|
||||
"root_of_trust": {
|
||||
"product": "Milan",
|
||||
"check_crl": true,
|
||||
"disallow_network": false
|
||||
"disallow_network": false,
|
||||
"product_line": "Milan"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use base64::prelude::*;
|
||||
use clap::{value_parser, Arg, Command};
|
||||
use serde::Serialize;
|
||||
use sev::firmware::host::*;
|
||||
@@ -23,24 +24,19 @@ struct SevProduct {
|
||||
name: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Vmpl {
|
||||
value: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SnpPolicy {
|
||||
policy: u64,
|
||||
family_id: Vec<u8>,
|
||||
image_id: Vec<u8>,
|
||||
vmpl: Vmpl,
|
||||
family_id: String,
|
||||
image_id: String,
|
||||
vmpl: u32,
|
||||
minimum_tcb: u64,
|
||||
minimum_launch_tcb: u64,
|
||||
require_author_key: bool,
|
||||
measurement: Vec<u8>,
|
||||
host_data: Vec<u8>,
|
||||
report_id_ma: Vec<u8>,
|
||||
chip_id: Vec<u8>,
|
||||
measurement: String,
|
||||
host_data: String,
|
||||
report_id_ma: String,
|
||||
chip_id: String,
|
||||
minimum_build: u32,
|
||||
minimum_version: String,
|
||||
permit_provisional_firmware: bool,
|
||||
@@ -58,7 +54,7 @@ struct RootOfTrust {
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Computation {
|
||||
snp_policy: SnpPolicy,
|
||||
policy: SnpPolicy,
|
||||
root_of_trust: RootOfTrust,
|
||||
}
|
||||
|
||||
@@ -125,24 +121,24 @@ fn main() {
|
||||
let status: SnpPlatformStatus = firmware.snp_platform_status().unwrap();
|
||||
|
||||
let policy: u64 = *matches.get_one::<u64>("policy").unwrap();
|
||||
let family_id = vec![0; 16];
|
||||
let image_id = vec![0; 16];
|
||||
let vmpl = Vmpl { value: 0 };
|
||||
let family_id = BASE64_STANDARD.encode(vec![0; 16]);
|
||||
let image_id = BASE64_STANDARD.encode(vec![0; 16]);
|
||||
let vmpl = 0;
|
||||
let minimum_tcb = get_uint64_from_tcb(&status.platform_tcb_version);
|
||||
let minimum_launch_tcb = get_uint64_from_tcb(&status.platform_tcb_version);
|
||||
let require_author_key = false;
|
||||
let measurement = vec![0];
|
||||
let host_data = vec![0];
|
||||
let report_id_ma = vec![0xFF; 32];
|
||||
let measurement = BASE64_STANDARD.encode(vec![0; 48]);
|
||||
let host_data = BASE64_STANDARD.encode(vec![0; 32]);
|
||||
let report_id_ma = BASE64_STANDARD.encode(vec![0xFF; 32]);
|
||||
let cpu_id: Identifier = firmware.get_identifier().unwrap();
|
||||
let chip_id: Vec<u8> = cpu_id.0;
|
||||
let chip_id: String = BASE64_STANDARD.encode(cpu_id.0);
|
||||
let minimum_build = status.build_id;
|
||||
let minimum_version = status.version.to_string();
|
||||
let permit_provisional_firmware = false;
|
||||
let require_id_block = false;
|
||||
let product = sev_product(get_sev_snp_processor());
|
||||
|
||||
let snp_policy = SnpPolicy {
|
||||
let policy = SnpPolicy {
|
||||
policy,
|
||||
family_id,
|
||||
image_id,
|
||||
@@ -169,7 +165,7 @@ fn main() {
|
||||
};
|
||||
|
||||
let computation = Computation {
|
||||
snp_policy,
|
||||
policy,
|
||||
root_of_trust,
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user