mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-326 - Add vTPM support to CoCoS (#376)
* manager, cli and agent vtpm support * rebase and changed atls for vtpm * deleted unused code * changed chekproto.yaml script so it find the manager proto file correctly * fixe manager proto version * fix agent tests * fix server agent test * fix attestation test * fix attestation test gofumpt * created dummy RWC for TPM * fix comment * add default PCR values * rebase main * fix rust ci and missing header * changed embedded attestation to VMPL 2 * fix unused impot * fix pkg test * address attestation type * fix agent attestation test * add prc15 check * fix comments * fix cli tests * add doc * add mock for LeveledQuoteProvider when SEV-SNP device is not found Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix manager reading attestation policy * refactor PCR value checks and update attestation policy values Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests for sev and grpc --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com> Co-authored-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
fa26573643
commit
67f939fc66
+24
-20
@@ -20,8 +20,8 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -51,8 +51,8 @@ var (
|
||||
errConnCreate = errors.New("could not create connection")
|
||||
)
|
||||
|
||||
type ValidationVerification func(data1, data2 []byte) error
|
||||
type FetchAttestation func(data1 []byte) ([]byte, error)
|
||||
type ValidationVerification func(data1, data2, data3, data4 []byte) error
|
||||
type FetchAttestation func(data1, data2, data3 []byte) ([]byte, error)
|
||||
|
||||
func registerFetchAttestation(callback FetchAttestation) uintptr {
|
||||
handle := cgo.NewHandle(callback)
|
||||
@@ -70,7 +70,7 @@ func validationVerificationCallback(teeType C.int) uintptr {
|
||||
case NoTee:
|
||||
return uintptr(0)
|
||||
case AmdSevSnp:
|
||||
return registerValidationVerification(quoteprovider.VerifyAttestationReportTLS)
|
||||
return registerValidationVerification(vtpm.VTPMVerify)
|
||||
default:
|
||||
return uintptr(0)
|
||||
}
|
||||
@@ -82,22 +82,24 @@ func fetchAttestationCallback(teeType C.int) uintptr {
|
||||
case NoTee:
|
||||
return uintptr(0)
|
||||
case AmdSevSnp:
|
||||
return registerFetchAttestation(quoteprovider.FetchAttestation)
|
||||
return registerFetchAttestation(vtpm.FetchATLSQuote)
|
||||
default:
|
||||
return uintptr(0)
|
||||
}
|
||||
}
|
||||
|
||||
//export callVerificationValidationCallback
|
||||
func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uchar, attReportSize C.int, repData *C.uchar) C.int {
|
||||
func callVerificationValidationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, quote *C.uchar, quoteSize C.int, teeNonce *C.uchar, nonce *C.uchar) C.int {
|
||||
handle := cgo.Handle(callbackHandle)
|
||||
defer handle.Delete()
|
||||
|
||||
callback := handle.Value().(ValidationVerification)
|
||||
attestationReport := C.GoBytes(unsafe.Pointer(attReport), attReportSize)
|
||||
reportData := C.GoBytes(unsafe.Pointer(repData), agent.ReportDataSize)
|
||||
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
|
||||
attestationReport := C.GoBytes(unsafe.Pointer(quote), quoteSize)
|
||||
teeData := C.GoBytes(unsafe.Pointer(teeNonce), quoteprovider.Nonce)
|
||||
nonceData := C.GoBytes(unsafe.Pointer(nonce), vtpm.Nonce)
|
||||
|
||||
err := callback(attestationReport, reportData)
|
||||
err := callback(attestationReport, pubKeyCert, teeData, nonceData)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "callback failed %v", err)
|
||||
return C.int(-1)
|
||||
@@ -107,20 +109,22 @@ func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uch
|
||||
}
|
||||
|
||||
//export callFetchAttestationCallback
|
||||
func callFetchAttestationCallback(callbackHandle uintptr, reportDataByte *C.uchar, outlen *C.int) *C.uchar {
|
||||
func callFetchAttestationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar, outlen *C.ulong) *C.uchar {
|
||||
handle := cgo.Handle(callbackHandle)
|
||||
defer handle.Delete()
|
||||
|
||||
callback := handle.Value().(FetchAttestation)
|
||||
reportData := C.GoBytes(unsafe.Pointer(reportDataByte), agent.ReportDataSize)
|
||||
pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen)
|
||||
teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce)
|
||||
vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce)
|
||||
|
||||
quote, err := callback(reportData)
|
||||
quote, err := callback(pubKeyCert, teeNonceData, vTPMNonce)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "attestation callback returned nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
*outlen = C.int(len(quote))
|
||||
*outlen = C.ulong(len(quote))
|
||||
resultC := C.malloc(C.size_t(len(quote)))
|
||||
if resultC == nil {
|
||||
fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback")
|
||||
@@ -232,23 +236,23 @@ func (c *ATLSConn) Read(b []byte) (int, error) {
|
||||
case noError:
|
||||
return n, nil // no error.
|
||||
case errorZeroReturn:
|
||||
fmt.Fprintf(os.Stdout, "Connection closed by peer")
|
||||
fmt.Fprintf(os.Stdout, "Connection closed by peer\n")
|
||||
return 0, io.EOF // connection closed.
|
||||
case errorWantRead:
|
||||
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later")
|
||||
fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later\n")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case errorWantWrite:
|
||||
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later")
|
||||
fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later\n")
|
||||
return 0, nil // non-fatal, just retry later.
|
||||
case errorSyscall:
|
||||
fmt.Fprintf(os.Stderr, "I/O error")
|
||||
fmt.Fprintf(os.Stderr, "I/O error\n")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
case errorSsl:
|
||||
fmt.Fprintf(os.Stderr, "I/O error")
|
||||
fmt.Fprintf(os.Stderr, "I/O error\n")
|
||||
return 0, syscall.ECONNRESET // return connection reset error.
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode)
|
||||
return 0, fmt.Errorf("SSL error")
|
||||
return 0, fmt.Errorf("SSL error\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,7 +301,7 @@ func (c *ATLSConn) Close() error {
|
||||
return errTLSConn
|
||||
} else if int(ret) == 1 {
|
||||
c.tlsConn = nil
|
||||
break;
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+74
-97
@@ -7,30 +7,27 @@
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* attReport, int attReportSize, const u_char* repData);
|
||||
extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* reportDataByte, int* outlen);
|
||||
extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* quote, int quoteSize, const u_char* teeNonce, const u_char* nonce);
|
||||
extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* teeNonceByte, const u_char* vTPMNonceByte, unsigned long* outlen);
|
||||
extern uintptr_t validationVerificationCallback(int teeType);
|
||||
extern uintptr_t fetchAttestationCallback(int teeType);
|
||||
|
||||
int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char *attestationReport, int reportSize, u_char *reportData) {
|
||||
if (attestationReport == NULL || reportData == NULL) {
|
||||
fprintf(stderr, "attestation data and report data cannot be NULL\n");
|
||||
int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char* pub_key, int pub_key_len, u_char *quote, int quote_size, u_char *tee_nonce, u_char *vtpm_nonce) {
|
||||
if (quote == NULL || vtpm_nonce == NULL || tee_nonce == NULL || pub_key == NULL) {
|
||||
fprintf(stderr, "attestation and noce and public key cannot be NULL\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
return callVerificationValidationCallback(callbackHandle, attestationReport, reportSize, reportData);
|
||||
return callVerificationValidationCallback(callbackHandle, pub_key, pub_key_len, quote, quote_size, tee_nonce, vtpm_nonce);
|
||||
}
|
||||
|
||||
u_char* triggerFetchAttestationCallback(uintptr_t callbackHandle, char *reportData) {
|
||||
int outlen = REPORT_DATA_SIZE;
|
||||
|
||||
if(reportData == NULL) {
|
||||
u_char* triggerFetchAttestationCallback(uintptr_t callback_handle, u_char* pub_key, int pub_key_len, char *tee_nonce, char *vtpm_nonce, unsigned long *outlen) {
|
||||
if(tee_nonce == NULL || vtpm_nonce == NULL) {
|
||||
fprintf(stderr, "Report data cannot be NULL");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return callFetchAttestationCallback(callbackHandle, reportData, &outlen);
|
||||
return callFetchAttestationCallback(callback_handle, pub_key, pub_key_len, tee_nonce, vtpm_nonce, outlen);
|
||||
}
|
||||
|
||||
int check_sev_snp() {
|
||||
@@ -47,46 +44,6 @@ int check_sev_snp() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int compute_sha256_of_public_key_nonce(X509 *cert, u_char *nonce, u_char *hash) {
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
u_char *concatinated = NULL;
|
||||
int pubkey_len = 0;
|
||||
int totla_len = 0;
|
||||
|
||||
pkey = X509_get_pubkey(cert);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
|
||||
totla_len = pubkey_len + CLIENT_RANDOM_SIZE;
|
||||
concatinated = (u_char*)malloc(totla_len);
|
||||
if (concatinated == NULL) {
|
||||
perror("failed to allocate memory");
|
||||
return -1;
|
||||
}
|
||||
memcpy(concatinated, nonce, CLIENT_RANDOM_SIZE);
|
||||
memcpy(concatinated + CLIENT_RANDOM_SIZE, pubkey_buf, pubkey_len);
|
||||
|
||||
// Compute the SHA-512 hash of the DER-encoded public key and the random nonce
|
||||
SHA512(concatinated, totla_len, hash);
|
||||
|
||||
// Clean up
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
free(concatinated);
|
||||
|
||||
return 0; // Success
|
||||
}
|
||||
|
||||
/*
|
||||
Evidence request extension
|
||||
- Contains a random nonce that goes into the attestation report
|
||||
@@ -121,9 +78,14 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
}
|
||||
|
||||
if (ext_data != NULL) {
|
||||
if (RAND_bytes(ext_data->er.data, CLIENT_RANDOM_SIZE) != 1) {
|
||||
perror("could not generate random bytes, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.data, CLIENT_RANDOM_SIZE);
|
||||
if (RAND_bytes(ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE) != 1) {
|
||||
perror("could not generate random bytes for vtpm nonce, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE);
|
||||
}
|
||||
|
||||
if (RAND_bytes(ext_data->er.tee_nonce, REPORT_DATA_SIZE) != 1) {
|
||||
perror("could not generate random bytes for tee nonce, will use SSL client random");
|
||||
SSL_get_client_random(s, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
@@ -132,7 +94,8 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(er->data, ext_data->er.data, CLIENT_RANDOM_SIZE);
|
||||
memcpy(er->vtpm_nonce, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE);
|
||||
memcpy(er->tee_nonce, ext_data->er.tee_nonce, REPORT_DATA_SIZE);
|
||||
er->tee_type = AMD_TEE;
|
||||
ext_data->er.tee_type = AMD_TEE;
|
||||
|
||||
@@ -201,7 +164,8 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
evidence_request *er = (evidence_request*)in;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
memcpy(ext_data->er.data, er->data, CLIENT_RANDOM_SIZE);
|
||||
memcpy(ext_data->er.vtpm_nonce, er->vtpm_nonce, CLIENT_RANDOM_SIZE);
|
||||
memcpy(ext_data->er.tee_nonce, er->tee_nonce, REPORT_DATA_SIZE);
|
||||
ext_data->er.tee_type = er->tee_type;
|
||||
} else {
|
||||
fprintf(stderr, "parse_arg is NULL\n");
|
||||
@@ -238,7 +202,7 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
/*
|
||||
Attestation Certificate extension
|
||||
- Contains the attestation report
|
||||
- The attestation report contains the hash of the nonce and the Public Key of the x.509 Agent certificate
|
||||
- The attestation report contains the hash of the nonce, the Public Key of the x.509 Agent certificate, and the vTPM AK
|
||||
*/
|
||||
void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type,
|
||||
unsigned int context,
|
||||
@@ -263,40 +227,46 @@ int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type,
|
||||
{
|
||||
tls_extension_data *ext_data = (tls_extension_data*)add_arg;
|
||||
if (ext_data != NULL) {
|
||||
u_char *attestation_report;
|
||||
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
|
||||
|
||||
if (hash == NULL) {
|
||||
perror("could not allocate memory");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
u_char *quote;
|
||||
size_t len = 0;
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
int pubkey_len = 0;
|
||||
|
||||
|
||||
if (x != NULL) {
|
||||
int ret = compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash);
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "error while calculating hash\n");
|
||||
free(hash);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
pkey = X509_get_pubkey(x);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "agent certificate must be used for aTLS\n");
|
||||
free(hash);
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
return -1;
|
||||
}
|
||||
|
||||
attestation_report = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, hash);
|
||||
if (attestation_report == NULL) {
|
||||
quote = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, pubkey_buf, pubkey_len, ext_data->er.tee_nonce, ext_data->er.vtpm_nonce, &len);
|
||||
if (quote == NULL) {
|
||||
fprintf(stderr, "attestation report is NULL\n");
|
||||
*al = SSL_AD_INTERNAL_ERROR;
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
return -1;
|
||||
}
|
||||
free(hash);
|
||||
|
||||
*out = attestation_report;
|
||||
*outlen = ATTESTATION_REPORT_SIZE;
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
|
||||
*out = quote;
|
||||
*outlen = len;
|
||||
return 1;
|
||||
} else {
|
||||
fprintf(stderr, "add_arg is NULL\n");
|
||||
@@ -329,34 +299,41 @@ int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type,
|
||||
tls_extension_data *ext_data = (tls_extension_data*)parse_arg;
|
||||
|
||||
if (ext_data != NULL) {
|
||||
char *attestation_report = (char*)malloc(ATTESTATION_REPORT_SIZE*sizeof(char));
|
||||
u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char));
|
||||
char *quote = (char*)malloc(inlen*sizeof(char));
|
||||
EVP_PKEY *pkey = NULL;
|
||||
u_char *pubkey_buf = NULL;
|
||||
int pubkey_len = 0;
|
||||
int res = 0;
|
||||
|
||||
if (hash == NULL || attestation_report == NULL) {
|
||||
if (quote == NULL) {
|
||||
perror("could not allocate memory");
|
||||
|
||||
if (hash != NULL) free(hash);
|
||||
if (attestation_report != NULL) free(attestation_report);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash) != 0) {
|
||||
fprintf(stderr, "calculating hash failed\n");
|
||||
free(attestation_report);
|
||||
free(hash);
|
||||
return 0;
|
||||
pkey = X509_get_pubkey(x);
|
||||
if (pkey == NULL) {
|
||||
fprintf(stderr, "Failed to extract public key from certificate\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf);
|
||||
if (pubkey_len <= 0) {
|
||||
fprintf(stderr, "Failed to convert public key to DER format\n");
|
||||
EVP_PKEY_free(pkey);
|
||||
return -1;
|
||||
}
|
||||
memcpy(quote, in, inlen);
|
||||
|
||||
memcpy(attestation_report, in, inlen);
|
||||
|
||||
res = triggerVerificationValidationCallback(ext_data->verification_validation_handler,
|
||||
attestation_report,
|
||||
ATTESTATION_REPORT_SIZE,
|
||||
hash);
|
||||
free(attestation_report);
|
||||
free(hash);
|
||||
res = triggerVerificationValidationCallback(ext_data->verification_validation_handler,
|
||||
pubkey_buf,
|
||||
pubkey_len,
|
||||
quote,
|
||||
inlen,
|
||||
(u_char*)&ext_data->er.tee_nonce,
|
||||
(u_char*)&ext_data->er.vtpm_nonce);
|
||||
free(quote);
|
||||
EVP_PKEY_free(pkey);
|
||||
OPENSSL_free(pubkey_buf);
|
||||
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "verification and validation failed, aborting connection\n");
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65
|
||||
#define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66
|
||||
#define ATTESTATION_REPORT_SIZE 0x4A0
|
||||
#define REPORT_DATA_SIZE 64
|
||||
#define CLIENT_RANDOM_SIZE 32
|
||||
#define TLS_CLIENT_CTX 0
|
||||
@@ -19,7 +18,8 @@
|
||||
typedef struct evidence_request
|
||||
{
|
||||
int tee_type;
|
||||
char data[CLIENT_RANDOM_SIZE];
|
||||
char vtpm_nonce[CLIENT_RANDOM_SIZE];
|
||||
char tee_nonce[REPORT_DATA_SIZE];
|
||||
} evidence_request;
|
||||
|
||||
typedef struct tls_extension_data
|
||||
@@ -32,7 +32,7 @@ typedef struct tls_extension_data
|
||||
typedef struct tls_server_connection
|
||||
{
|
||||
int server_fd;
|
||||
char* cert;
|
||||
char* cert;
|
||||
int cert_len;
|
||||
char* key;
|
||||
int key_len;
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
type AttestationType int32
|
||||
|
||||
const (
|
||||
SNP AttestationType = iota
|
||||
VTPM
|
||||
SNPvTPM
|
||||
)
|
||||
|
||||
var (
|
||||
AttestationPolicy = Config{SnpCheck: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &PcrConfig{}}
|
||||
ErrAttestationPolicyOpen = errors.New("failed to open Attestation Policy file")
|
||||
ErrAttestationPolicyDecode = errors.New("failed to decode Attestation Policy file")
|
||||
ErrAttestationPolicyMissing = errors.New("failed due to missing Attestation Policy file")
|
||||
)
|
||||
|
||||
type PcrValues struct {
|
||||
Sha256 map[string]string `json:"sha256"`
|
||||
Sha384 map[string]string `json:"sha384"`
|
||||
}
|
||||
|
||||
type PcrConfig struct {
|
||||
PCRValues PcrValues `json:"pcr_values"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
SnpCheck *check.Config
|
||||
PcrConfig *PcrConfig
|
||||
}
|
||||
|
||||
func ReadAttestationPolicy(policyPath string, attestationConfiguration *Config) error {
|
||||
if policyPath != "" {
|
||||
policyData, err := os.ReadFile(policyPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(ErrAttestationPolicyOpen, err)
|
||||
}
|
||||
|
||||
return ReadAttestationPolicyFromByte(policyData, attestationConfiguration)
|
||||
}
|
||||
|
||||
return ErrAttestationPolicyMissing
|
||||
}
|
||||
|
||||
func ReadAttestationPolicyFromByte(policyData []byte, attestationConfiguration *Config) error {
|
||||
unmarshalOptions := protojson.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true}
|
||||
|
||||
if err := unmarshalOptions.Unmarshal(policyData, attestationConfiguration.SnpCheck); err != nil {
|
||||
return errors.Wrap(ErrAttestationPolicyDecode, err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(policyData, attestationConfiguration.PcrConfig); err != nil {
|
||||
return errors.Wrap(ErrAttestationPolicyDecode, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -8,26 +8,24 @@ package quoteprovider
|
||||
|
||||
import (
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
pb "github.com/google/go-sev-guest/proto/sevsnp"
|
||||
cocosai "github.com/ultravioletrs/cocos"
|
||||
)
|
||||
|
||||
var (
|
||||
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
)
|
||||
const Nonce = 64
|
||||
|
||||
var _ client.QuoteProvider = (*embeddedQuoteProvider)(nil)
|
||||
var _ client.LeveledQuoteProvider = (*embeddedQuoteProvider)(nil)
|
||||
|
||||
type embeddedQuoteProvider struct {
|
||||
}
|
||||
|
||||
func GetQuoteProvider() (client.QuoteProvider, error) {
|
||||
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
return &embeddedQuoteProvider{}, nil
|
||||
}
|
||||
|
||||
// GetQuote returns the SEV quote for the given report data.
|
||||
func (e *embeddedQuoteProvider) GetRawQuote(reportData [64]byte) ([]byte, error) {
|
||||
// GetRawQuoteAtLevel returns the SEV quote for the given report data and VMPL.
|
||||
func (e *embeddedQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]byte, error) {
|
||||
return cocosai.EmbeddedAttestation, nil
|
||||
}
|
||||
|
||||
@@ -46,6 +44,6 @@ func FetchAttestation(reportDataSlice []byte) ([]byte, error) {
|
||||
return cocosai.EmbeddedAttestation, nil
|
||||
}
|
||||
|
||||
func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error {
|
||||
func VerifyAttestationReportTLS(attestation *sevsnp.Attestation, reportData []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,42 +10,42 @@ import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// QuoteProvider is an autogenerated mock type for the QuoteProvider type
|
||||
type QuoteProvider struct {
|
||||
// LeveledQuoteProvider is an autogenerated mock type for the LeveledQuoteProvider type
|
||||
type LeveledQuoteProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type QuoteProvider_Expecter struct {
|
||||
type LeveledQuoteProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *QuoteProvider) EXPECT() *QuoteProvider_Expecter {
|
||||
return &QuoteProvider_Expecter{mock: &_m.Mock}
|
||||
func (_m *LeveledQuoteProvider) EXPECT() *LeveledQuoteProvider_Expecter {
|
||||
return &LeveledQuoteProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// GetRawQuote provides a mock function with given fields: reportData
|
||||
func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) {
|
||||
ret := _m.Called(reportData)
|
||||
// GetRawQuoteAtLevel provides a mock function with given fields: reportData, vmpl
|
||||
func (_m *LeveledQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]uint8, error) {
|
||||
ret := _m.Called(reportData, vmpl)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetRawQuote")
|
||||
panic("no return value specified for GetRawQuoteAtLevel")
|
||||
}
|
||||
|
||||
var r0 []uint8
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok {
|
||||
return rf(reportData)
|
||||
if rf, ok := ret.Get(0).(func([64]byte, uint) ([]uint8, error)); ok {
|
||||
return rf(reportData, vmpl)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok {
|
||||
r0 = rf(reportData)
|
||||
if rf, ok := ret.Get(0).(func([64]byte, uint) []uint8); ok {
|
||||
r0 = rf(reportData, vmpl)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]uint8)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func([64]byte) error); ok {
|
||||
r1 = rf(reportData)
|
||||
if rf, ok := ret.Get(1).(func([64]byte, uint) error); ok {
|
||||
r1 = rf(reportData, vmpl)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -53,36 +53,37 @@ func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) {
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// QuoteProvider_GetRawQuote_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuote'
|
||||
type QuoteProvider_GetRawQuote_Call struct {
|
||||
// LeveledQuoteProvider_GetRawQuoteAtLevel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuoteAtLevel'
|
||||
type LeveledQuoteProvider_GetRawQuoteAtLevel_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetRawQuote is a helper method to define mock.On call
|
||||
// GetRawQuoteAtLevel is a helper method to define mock.On call
|
||||
// - reportData [64]byte
|
||||
func (_e *QuoteProvider_Expecter) GetRawQuote(reportData interface{}) *QuoteProvider_GetRawQuote_Call {
|
||||
return &QuoteProvider_GetRawQuote_Call{Call: _e.mock.On("GetRawQuote", reportData)}
|
||||
// - vmpl uint
|
||||
func (_e *LeveledQuoteProvider_Expecter) GetRawQuoteAtLevel(reportData interface{}, vmpl interface{}) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call {
|
||||
return &LeveledQuoteProvider_GetRawQuoteAtLevel_Call{Call: _e.mock.On("GetRawQuoteAtLevel", reportData, vmpl)}
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_GetRawQuote_Call) Run(run func(reportData [64]byte)) *QuoteProvider_GetRawQuote_Call {
|
||||
func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Run(run func(reportData [64]byte, vmpl uint)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([64]byte))
|
||||
run(args[0].([64]byte), args[1].(uint))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_GetRawQuote_Call) Return(_a0 []uint8, _a1 error) *QuoteProvider_GetRawQuote_Call {
|
||||
func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Return(_a0 []uint8, _a1 error) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_GetRawQuote_Call) RunAndReturn(run func([64]byte) ([]uint8, error)) *QuoteProvider_GetRawQuote_Call {
|
||||
func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) RunAndReturn(run func([64]byte, uint) ([]uint8, error)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// IsSupported provides a mock function with given fields:
|
||||
func (_m *QuoteProvider) IsSupported() bool {
|
||||
func (_m *LeveledQuoteProvider) IsSupported() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
@@ -99,35 +100,35 @@ func (_m *QuoteProvider) IsSupported() bool {
|
||||
return r0
|
||||
}
|
||||
|
||||
// QuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported'
|
||||
type QuoteProvider_IsSupported_Call struct {
|
||||
// LeveledQuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported'
|
||||
type LeveledQuoteProvider_IsSupported_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// IsSupported is a helper method to define mock.On call
|
||||
func (_e *QuoteProvider_Expecter) IsSupported() *QuoteProvider_IsSupported_Call {
|
||||
return &QuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")}
|
||||
func (_e *LeveledQuoteProvider_Expecter) IsSupported() *LeveledQuoteProvider_IsSupported_Call {
|
||||
return &LeveledQuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")}
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_IsSupported_Call) Run(run func()) *QuoteProvider_IsSupported_Call {
|
||||
func (_c *LeveledQuoteProvider_IsSupported_Call) Run(run func()) *LeveledQuoteProvider_IsSupported_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_IsSupported_Call) Return(_a0 bool) *QuoteProvider_IsSupported_Call {
|
||||
func (_c *LeveledQuoteProvider_IsSupported_Call) Return(_a0 bool) *LeveledQuoteProvider_IsSupported_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *QuoteProvider_IsSupported_Call {
|
||||
func (_c *LeveledQuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *LeveledQuoteProvider_IsSupported_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Product provides a mock function with given fields:
|
||||
func (_m *QuoteProvider) Product() *sevsnp.SevProduct {
|
||||
func (_m *LeveledQuoteProvider) Product() *sevsnp.SevProduct {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
@@ -146,40 +147,40 @@ func (_m *QuoteProvider) Product() *sevsnp.SevProduct {
|
||||
return r0
|
||||
}
|
||||
|
||||
// QuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product'
|
||||
type QuoteProvider_Product_Call struct {
|
||||
// LeveledQuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product'
|
||||
type LeveledQuoteProvider_Product_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Product is a helper method to define mock.On call
|
||||
func (_e *QuoteProvider_Expecter) Product() *QuoteProvider_Product_Call {
|
||||
return &QuoteProvider_Product_Call{Call: _e.mock.On("Product")}
|
||||
func (_e *LeveledQuoteProvider_Expecter) Product() *LeveledQuoteProvider_Product_Call {
|
||||
return &LeveledQuoteProvider_Product_Call{Call: _e.mock.On("Product")}
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_Product_Call) Run(run func()) *QuoteProvider_Product_Call {
|
||||
func (_c *LeveledQuoteProvider_Product_Call) Run(run func()) *LeveledQuoteProvider_Product_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *QuoteProvider_Product_Call {
|
||||
func (_c *LeveledQuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *QuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *QuoteProvider_Product_Call {
|
||||
func (_c *LeveledQuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// NewLeveledQuoteProvider creates a new instance of LeveledQuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewQuoteProvider(t interface {
|
||||
func NewLeveledQuoteProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *QuoteProvider {
|
||||
mock := &QuoteProvider{}
|
||||
}) *LeveledQuoteProvider {
|
||||
mock := &LeveledQuoteProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
@@ -22,6 +21,7 @@ import (
|
||||
"github.com/google/go-sev-guest/verify"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/google/logger"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
@@ -29,20 +29,19 @@ const (
|
||||
cocosDirectory = ".cocos"
|
||||
caBundleName = "ask_ark.pem"
|
||||
attestationReportSize = 0x4A0
|
||||
reportDataSize = 64
|
||||
Nonce = 64
|
||||
sevProductNameMilan = "Milan"
|
||||
sevProductNameGenoa = "Genoa"
|
||||
sevVMPL = 2
|
||||
)
|
||||
|
||||
var (
|
||||
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
timeout = time.Minute * 2
|
||||
maxTryDelay = time.Second * 30
|
||||
timeout = time.Minute * 2
|
||||
maxTryDelay = time.Second * 30
|
||||
)
|
||||
|
||||
var (
|
||||
errProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa))
|
||||
errReportSize = errors.New("attestation report size mismatch")
|
||||
errAttVerification = errors.New("attestation verification failed")
|
||||
errAttValidation = errors.New("attestation validation failed")
|
||||
)
|
||||
@@ -138,38 +137,31 @@ func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetQuoteProvider() (client.QuoteProvider, error) {
|
||||
return client.GetQuoteProvider()
|
||||
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
return client.GetLeveledQuoteProvider()
|
||||
}
|
||||
|
||||
func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error {
|
||||
config, err := copyConfig(&AttConfigurationSEVSNP)
|
||||
func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte) error {
|
||||
config, err := copyConfig(config.AttestationPolicy.SnpCheck)
|
||||
if err != nil {
|
||||
return errors.Wrap(fmt.Errorf("failed to create a copy of attestation policy"), err)
|
||||
}
|
||||
|
||||
// Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report.
|
||||
// This data is not part of the attestation report and it will be ignored.
|
||||
attestationPB.CertificateChain = nil
|
||||
config.Policy.ReportData = reportData[:]
|
||||
return VerifyAndValidate(attestationBytes, config)
|
||||
return VerifyAndValidate(attestationPB, config)
|
||||
}
|
||||
|
||||
func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error {
|
||||
func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
logger.Init("", false, false, io.Discard)
|
||||
|
||||
if len(attestationReport) < attestationReportSize {
|
||||
return errReportSize
|
||||
}
|
||||
attestationBytes := attestationReport[:attestationReportSize]
|
||||
|
||||
attestationPB, err := abi.ReportCertsToProto(attestationBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert attestation bytes to struct %v", errors.Wrap(errAttVerification, err))
|
||||
}
|
||||
|
||||
if err = verifyReport(attestationPB, cfg); err != nil {
|
||||
if err := verifyReport(attestationPB, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = validateReport(attestationPB, cfg); err != nil {
|
||||
if err := validateReport(attestationPB, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -177,19 +169,19 @@ func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error {
|
||||
}
|
||||
|
||||
func FetchAttestation(reportDataSlice []byte) ([]byte, error) {
|
||||
var reportData [reportDataSize]byte
|
||||
var reportData [Nonce]byte
|
||||
|
||||
qp, err := GetQuoteProvider()
|
||||
qp, err := GetLeveledQuoteProvider()
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > reportDataSize {
|
||||
if len(reportData) > Nonce {
|
||||
return []byte{}, fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
copy(reportData[:], reportDataSlice)
|
||||
|
||||
rawQuote, err := qp.GetRawQuote(reportData)
|
||||
rawQuote, err := qp.GetRawQuoteAtLevel(reportData, sevVMPL)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("failed to get raw quote")
|
||||
}
|
||||
|
||||
@@ -17,14 +17,10 @@ import (
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const (
|
||||
measurementOffset = 0x90
|
||||
signatureOffset = 0x2A0
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
@@ -76,18 +72,18 @@ func TestFillInAttestationLocal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportSuccess(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
attestationPB, reportData := prepVerifyAttReport(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
attestationReport *sevsnp.Attestation
|
||||
reportData []byte
|
||||
goodProduct int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, validation and verification is performed succsessfully",
|
||||
attestationReport: file,
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
goodProduct: 1,
|
||||
err: nil,
|
||||
@@ -103,20 +99,20 @@ func TestVerifyAttestationReportSuccess(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
attestationPB, reportData := prepVerifyAttReport(t)
|
||||
|
||||
// Change random data so in the signature so the signature failes
|
||||
file[signatureOffset] = file[signatureOffset] ^ 0x01
|
||||
attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
attestationReport *sevsnp.Attestation
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, distorted signature",
|
||||
attestationReport: file,
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: errAttVerification,
|
||||
},
|
||||
@@ -131,17 +127,17 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
attestationPB, reportData := prepVerifyAttReport(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
attestationReport *sevsnp.Attestation
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, unknown product",
|
||||
attestationReport: file,
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: errProductLine,
|
||||
},
|
||||
@@ -149,8 +145,8 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
AttConfigurationSEVSNP.RootOfTrust.ProductLine = ""
|
||||
AttConfigurationSEVSNP.Policy.Product = nil
|
||||
config.AttestationPolicy.SnpCheck.RootOfTrust.ProductLine = ""
|
||||
config.AttestationPolicy.SnpCheck.Policy.Product = nil
|
||||
err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
@@ -158,20 +154,20 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
|
||||
file, reportData := prepareForTestVerifyAttestationReport(t)
|
||||
attestationPB, reportData := prepVerifyAttReport(t)
|
||||
|
||||
// Change random data in the measurement so the measurement does not match
|
||||
file[measurementOffset] = file[measurementOffset] ^ 0x01
|
||||
attestationPB.Report.Measurement[0] = attestationPB.Report.Measurement[0] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationReport []byte
|
||||
attestationReport *sevsnp.Attestation
|
||||
reportData []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation, malformed policy (measurement)",
|
||||
attestationReport: file,
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: errAttVerification,
|
||||
},
|
||||
@@ -185,32 +181,34 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) {
|
||||
func prepVerifyAttReport(t *testing.T) (*sevsnp.Attestation, []byte) {
|
||||
file, err := os.ReadFile("../../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
rr, err := abi.ReportCertsToProto(file)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(file) < attestationReportSize {
|
||||
file = append(file, make([]byte, attestationReportSize-len(file))...)
|
||||
}
|
||||
|
||||
AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
rr, err := abi.ReportCertsToProto(file)
|
||||
require.NoError(t, err)
|
||||
|
||||
config.AttestationPolicy = config.Config{SnpCheck: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &config.PcrConfig{}}
|
||||
|
||||
attestationPolicyFile, err := os.ReadFile("../../../scripts/attestation_policy/attestation_policy.json")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = protojson.Unmarshal(attestationPolicyFile, &AttConfigurationSEVSNP)
|
||||
unmarshalOptions := protojson.UnmarshalOptions{DiscardUnknown: true}
|
||||
|
||||
err = unmarshalOptions.Unmarshal(attestationPolicyFile, config.AttestationPolicy.SnpCheck)
|
||||
require.NoError(t, err)
|
||||
|
||||
AttConfigurationSEVSNP.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
|
||||
AttConfigurationSEVSNP.Policy.FamilyId = rr.Report.FamilyId
|
||||
AttConfigurationSEVSNP.Policy.ImageId = rr.Report.ImageId
|
||||
AttConfigurationSEVSNP.Policy.Measurement = rr.Report.Measurement
|
||||
AttConfigurationSEVSNP.Policy.HostData = rr.Report.HostData
|
||||
AttConfigurationSEVSNP.Policy.ReportIdMa = rr.Report.ReportIdMa
|
||||
AttConfigurationSEVSNP.RootOfTrust.ProductLine = sevProductNameMilan
|
||||
config.AttestationPolicy.SnpCheck.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN}
|
||||
config.AttestationPolicy.SnpCheck.Policy.FamilyId = rr.Report.FamilyId
|
||||
config.AttestationPolicy.SnpCheck.Policy.ImageId = rr.Report.ImageId
|
||||
config.AttestationPolicy.SnpCheck.Policy.Measurement = rr.Report.Measurement
|
||||
config.AttestationPolicy.SnpCheck.Policy.HostData = rr.Report.HostData
|
||||
config.AttestationPolicy.SnpCheck.Policy.ReportIdMa = rr.Report.ReportIdMa
|
||||
config.AttestationPolicy.SnpCheck.RootOfTrust.ProductLine = sevProductNameMilan
|
||||
|
||||
return file, rr.Report.ReportData
|
||||
return rr, rr.Report.ReportData
|
||||
}
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-tpm-tools/client"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/google/go-tpm-tools/proto/tpm"
|
||||
"github.com/google/go-tpm-tools/server"
|
||||
"github.com/google/go-tpm/legacy/tpm2"
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
eventLog = "/sys/kernel/security/tpm0/binary_bios_measurements"
|
||||
Nonce = 32
|
||||
PCR15 = 15
|
||||
Hash256 = 32
|
||||
Hash384 = 48
|
||||
)
|
||||
|
||||
var (
|
||||
ExternalTPM io.ReadWriteCloser
|
||||
ErrNoHashAlgo = errors.New("hash algo is not supported")
|
||||
)
|
||||
|
||||
type tpmWrapper struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (et tpmWrapper) EventLog() ([]byte, error) {
|
||||
return os.ReadFile(eventLog)
|
||||
}
|
||||
|
||||
func OpenTpm() (io.ReadWriteCloser, error) {
|
||||
if ExternalTPM != nil {
|
||||
return tpmWrapper{ExternalTPM}, nil
|
||||
}
|
||||
|
||||
tw := tpmWrapper{}
|
||||
var err error
|
||||
|
||||
tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpmrm0")
|
||||
if os.IsNotExist(err) {
|
||||
tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpm0")
|
||||
}
|
||||
|
||||
return tw, err
|
||||
}
|
||||
|
||||
func ExtendPCR(pcrIndex int, value []byte) error {
|
||||
rwc, err := OpenTpm()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rwc.Close()
|
||||
|
||||
fixedSha256Hash := sha3.Sum256(value)
|
||||
if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA256, fixedSha256Hash[:], ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fixedSha384Hash := sha3.Sum384(value)
|
||||
if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA384, fixedSha384Hash[:], ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool) ([]byte, error) {
|
||||
attestation, err := fetchVTPMQuote(vTPMNonce)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
if teeAttestaion {
|
||||
attestation, err = addTEEAttestation(attestation, teeNonce)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return marshalQuote(attestation)
|
||||
}
|
||||
|
||||
func FetchATLSQuote(pubKey, teeNonce, vTPMNonce []byte) ([]byte, error) {
|
||||
attestation, err := fetchVTPMQuote(vTPMNonce)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
reportData, err := createTEEAttestationReportNonce(pubKey, attestation.GetAkPub(), teeNonce)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
attestation, err = addTEEAttestation(attestation, reportData)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
return marshalQuote(attestation)
|
||||
}
|
||||
|
||||
func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byte) error {
|
||||
attestation := &attest.Attestation{}
|
||||
|
||||
err := proto.Unmarshal(quote, attestation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to unmarshal quote: %v", err)
|
||||
}
|
||||
|
||||
ak := attestation.GetAkPub()
|
||||
pub, err := tpm2.DecodePublic(ak)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cryptoPub, err := pub.Key()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reportData, err := createTEEAttestationReportNonce(pubKeyTLS, ak, teeNonce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fail to calculate report data: %v", err)
|
||||
}
|
||||
|
||||
if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), reportData); err != nil {
|
||||
return fmt.Errorf("failed to verify TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
_, err = server.VerifyAttestation(attestation, server.VerifyOpts{Nonce: vtpmNonce, TrustedAKs: []crypto.PublicKey{cryptoPub}})
|
||||
if err != nil {
|
||||
return fmt.Errorf("verifying attestation: %w", err)
|
||||
}
|
||||
|
||||
s256, s384 := calculatePCRTLSKey(pubKeyTLS)
|
||||
|
||||
if err := checkExpectedPCRValues(attestation, s256, s384); err != nil {
|
||||
return fmt.Errorf("PCR values do not match expected PCR values: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func publicKeyToBytes(pubKey interface{}) ([]byte, error) {
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return derBytes, nil
|
||||
}
|
||||
|
||||
func createTEEAttestationReportNonce(pubKeyTLS []byte, ak []byte, nonce []byte) ([]byte, error) {
|
||||
pub, err := tpm2.DecodePublic(ak)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
cryptoPub, err := pub.Key()
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
pubKeyBytes, err := publicKeyToBytes(cryptoPub)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
|
||||
reportData := append(append(pubKeyTLS, pubKeyBytes...), nonce...)
|
||||
hash := sha3.Sum512(reportData)
|
||||
|
||||
return hash[:], nil
|
||||
}
|
||||
|
||||
func marshalQuote(attestation *attest.Attestation) ([]byte, error) {
|
||||
out, err := proto.Marshal(attestation)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("failed to marshal vTPM attestation report: %v", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func fetchVTPMQuote(nonce []byte) (*attest.Attestation, error) {
|
||||
rwc, err := OpenTpm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rwc.Close()
|
||||
|
||||
attestationKey, err := client.AttestationKeyRSA(rwc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attestation key: %v", err)
|
||||
}
|
||||
defer attestationKey.Close()
|
||||
|
||||
var fixedNonce [Nonce]byte
|
||||
copy(fixedNonce[:], nonce)
|
||||
attestOpts := client.AttestOpts{}
|
||||
attestOpts.Nonce = fixedNonce[:]
|
||||
|
||||
attestOpts.TCGEventLog, err = client.GetEventLog(rwc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve TCG Event Log: %w", err)
|
||||
}
|
||||
|
||||
attestation, err := attestationKey.Attest(attestOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to collect attestation report: %v", err)
|
||||
}
|
||||
|
||||
return attestation, nil
|
||||
}
|
||||
|
||||
func addTEEAttestation(attestation *attest.Attestation, nonce []byte) (*attest.Attestation, error) {
|
||||
rawTeeAttestation, err := quoteprovider.FetchAttestation(nonce)
|
||||
if err != nil {
|
||||
return attestation, fmt.Errorf("failed to fetch TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
extReport, err := abi.ReportCertsToProto(rawTeeAttestation)
|
||||
if err != nil {
|
||||
return attestation, fmt.Errorf("failed to export the TEE report: %v", err)
|
||||
}
|
||||
attestation.TeeAttestation = &attest.Attestation_SevSnpAttestation{
|
||||
SevSnpAttestation: extReport,
|
||||
}
|
||||
|
||||
return attestation, nil
|
||||
}
|
||||
|
||||
func checkExpectedPCRValues(attestation *attest.Attestation, ePcr256 []byte, ePcr384 []byte) error {
|
||||
quotes := attestation.GetQuotes()
|
||||
for i := range quotes {
|
||||
quote := quotes[i]
|
||||
var pcrMap map[string]string
|
||||
var pcr15 []byte
|
||||
switch quote.Pcrs.Hash {
|
||||
case tpm.HashAlgo_SHA256:
|
||||
pcrMap = config.AttestationPolicy.PcrConfig.PCRValues.Sha256
|
||||
pcr15 = ePcr256
|
||||
case tpm.HashAlgo_SHA384:
|
||||
pcrMap = config.AttestationPolicy.PcrConfig.PCRValues.Sha384
|
||||
pcr15 = ePcr384
|
||||
default:
|
||||
return errors.Wrap(ErrNoHashAlgo, fmt.Errorf("algo: %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)]))
|
||||
}
|
||||
|
||||
pcr15Index := uint32(15)
|
||||
if !bytes.Equal(quote.Pcrs.Pcrs[pcr15Index], pcr15) {
|
||||
return fmt.Errorf("for algo %s PCR[15] expected %s but found %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)], hex.EncodeToString(pcr15), hex.EncodeToString(quote.Pcrs.Pcrs[pcr15Index]))
|
||||
}
|
||||
|
||||
for i, v := range pcrMap {
|
||||
index, err := strconv.ParseInt(i, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting PCR index to int32: %v\n", err)
|
||||
}
|
||||
value, err := hex.DecodeString(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting PCR value to byte: %v\n", err)
|
||||
}
|
||||
if !bytes.Equal(quote.Pcrs.Pcrs[uint32(index)], value) {
|
||||
return fmt.Errorf("for algo %s PCR[%d] expected %s but found %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)], index, hex.EncodeToString(value), hex.EncodeToString(quote.Pcrs.Pcrs[uint32(index)]))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return SHA256 and SHA384 values of the input public key.
|
||||
func calculatePCRTLSKey(pubKey []byte) ([]byte, []byte) {
|
||||
init256 := make([]byte, Hash256)
|
||||
init384 := make([]byte, Hash384)
|
||||
|
||||
key256 := sha3.Sum256(pubKey)
|
||||
key384 := sha3.Sum384(pubKey)
|
||||
|
||||
pcrValue256 := append(init256, key256[:]...)
|
||||
pcrValue384 := append(init384, key384[:]...)
|
||||
|
||||
newPcr256 := sha256.Sum256(pcrValue256)
|
||||
newPcr384 := sha512.Sum384(pcrValue384)
|
||||
|
||||
return newPcr256[:], newPcr384[:]
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
@@ -112,7 +113,7 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
err: pkggrpc.ErrAttestationPolicyMissing,
|
||||
err: config.ErrAttestationPolicyMissing,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
|
||||
err := ReadAttestationPolicy(cfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP)
|
||||
err := config.ReadAttestationPolicy(cfg.AttestationPolicy, &config.AttestationPolicy)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
att "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
@@ -200,8 +201,9 @@ func TestClientSecure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReadAttestationPolicy(t *testing.T) {
|
||||
validJSON := `{"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
|
||||
validJSON := `{"pcr_values":{"sha256":{"0":"123"},"sha384":{"0":"123"}},"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
|
||||
invalidJSON := `{"invalid_json"`
|
||||
invalidJSONPCR := `{"pcr_values":{"sha256":{"0":true},"sha384":{"0":"123"}},"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}`
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -219,19 +221,25 @@ func TestReadAttestationPolicy(t *testing.T) {
|
||||
name: "Invalid JSON",
|
||||
manifestPath: "invalid_manifest.json",
|
||||
fileContent: invalidJSON,
|
||||
err: ErrAttestationPolicyDecode,
|
||||
err: att.ErrAttestationPolicyDecode,
|
||||
},
|
||||
{
|
||||
name: "Non-existent file",
|
||||
manifestPath: "nonexistent.json",
|
||||
fileContent: "",
|
||||
err: errAttestationPolicyOpen,
|
||||
err: att.ErrAttestationPolicyOpen,
|
||||
},
|
||||
{
|
||||
name: "Empty manifest path",
|
||||
manifestPath: "",
|
||||
fileContent: "",
|
||||
err: ErrAttestationPolicyMissing,
|
||||
err: att.ErrAttestationPolicyMissing,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON PCR",
|
||||
manifestPath: "invalid_manifest.json",
|
||||
fileContent: invalidJSONPCR,
|
||||
err: att.ErrAttestationPolicyDecode,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -243,13 +251,13 @@ func TestReadAttestationPolicy(t *testing.T) {
|
||||
defer os.Remove(tt.manifestPath)
|
||||
}
|
||||
|
||||
config := check.Config{}
|
||||
err := ReadAttestationPolicy(tt.manifestPath, &config)
|
||||
config := att.Config{SnpCheck: &check.Config{}, PcrConfig: &att.PcrConfig{}}
|
||||
err := att.ReadAttestationPolicy(tt.manifestPath, &config)
|
||||
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
if tt.err == nil {
|
||||
assert.NotNil(t, config.Policy)
|
||||
assert.NotNil(t, config.RootOfTrust)
|
||||
assert.NotNil(t, config.SnpCheck.Policy)
|
||||
assert.NotNil(t, config.SnpCheck.RootOfTrust)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,12 +11,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
type security int
|
||||
@@ -37,9 +35,6 @@ const (
|
||||
var (
|
||||
errGrpcConnect = errors.New("failed to connect to grpc server")
|
||||
errGrpcClose = errors.New("failed to close grpc connection")
|
||||
errAttestationPolicyOpen = errors.New("failed to open Attestation Policy file")
|
||||
ErrAttestationPolicyMissing = errors.New("failed due to missing Attestation Policy file")
|
||||
ErrAttestationPolicyDecode = errors.New("failed to decode Attestation Policy file")
|
||||
errCertificateParse = errors.New("failed to parse x509 certificate")
|
||||
errAttVerification = errors.New("certificat is not sefl signed")
|
||||
errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
|
||||
@@ -55,7 +50,7 @@ type BaseConfig struct {
|
||||
Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"`
|
||||
ClientCert string `env:"CLIENT_CERT" envDefault:""`
|
||||
ClientKey string `env:"CLIENT_KEY" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
}
|
||||
|
||||
type AgentClientConfig struct {
|
||||
@@ -146,7 +141,9 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
|
||||
if err != nil {
|
||||
return nil, secure, err
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(tc))
|
||||
opts = append(opts, grpc.WithContextDialer(CustomDialer))
|
||||
secure = withaTLS
|
||||
} else {
|
||||
conf := cfg.GetBaseConfig()
|
||||
@@ -198,20 +195,3 @@ func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.Tran
|
||||
|
||||
return tc, nil, secure
|
||||
}
|
||||
|
||||
func ReadAttestationPolicy(manifestPath string, attestationConfiguration *check.Config) error {
|
||||
if manifestPath != "" {
|
||||
manifest, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttestationPolicyOpen, err)
|
||||
}
|
||||
|
||||
if err := protojson.Unmarshal(manifest, attestationConfiguration); err != nil {
|
||||
return errors.Wrap(ErrAttestationPolicyDecode, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrAttestationPolicyMissing
|
||||
}
|
||||
|
||||
+6
-3
@@ -26,11 +26,12 @@ type SDK interface {
|
||||
Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error
|
||||
Data(ctx context.Context, dataset *os.File, filename string, privKey any) error
|
||||
Result(ctx context.Context, privKey any, resultFile *os.File) error
|
||||
Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error
|
||||
Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error
|
||||
}
|
||||
|
||||
const (
|
||||
size64 = 64
|
||||
size32 = 32
|
||||
algoProgressBarDescription = "Uploading algorithm"
|
||||
dataProgressBarDescription = "Uploading data"
|
||||
resultProgressDescription = "Downloading result"
|
||||
@@ -120,9 +121,11 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.Fil
|
||||
return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error {
|
||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error {
|
||||
request := &agent.AttestationRequest{
|
||||
ReportData: reportData[:],
|
||||
TeeNonce: reportData[:],
|
||||
VtpmNonce: nonce[:],
|
||||
Type: int32(attType),
|
||||
}
|
||||
|
||||
stream, err := sdk.client.Attestation(ctx, request)
|
||||
|
||||
+15
-7
@@ -19,6 +19,8 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
@@ -364,6 +366,7 @@ func TestAttestation(t *testing.T) {
|
||||
resultConsumer1Key, _ := generateKeys(t, "ed25519")
|
||||
|
||||
reportData := make([]byte, 64)
|
||||
nonce := make([]byte, 64)
|
||||
report := []byte{
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
@@ -385,7 +388,8 @@ func TestAttestation(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
userKey any
|
||||
reportData [agent.ReportDataSize]byte
|
||||
reportData [quoteprovider.Nonce]byte
|
||||
nonce [vtpm.Nonce]byte
|
||||
response *agent.AttestationResponse
|
||||
svcRes []byte
|
||||
err error
|
||||
@@ -393,7 +397,8 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report successfully",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
},
|
||||
@@ -403,7 +408,8 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report with different key type",
|
||||
userKey: resultConsumer1Key,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
},
|
||||
@@ -413,7 +419,8 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "failed to fetch attestation report",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
},
|
||||
@@ -422,7 +429,8 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "invalid report data",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte{},
|
||||
reportData: [quoteprovider.Nonce]byte{},
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
},
|
||||
@@ -433,7 +441,7 @@ func TestAttestation(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
svcCall := svc.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
|
||||
file, err := os.CreateTemp("", "attestation")
|
||||
require.NoError(t, err)
|
||||
@@ -442,7 +450,7 @@ func TestAttestation(t *testing.T) {
|
||||
os.Remove(file.Name())
|
||||
})
|
||||
|
||||
err = sdk.Attestation(context.Background(), tc.reportData, file)
|
||||
err = sdk.Attestation(context.Background(), tc.reportData, tc.nonce, 0, file)
|
||||
|
||||
require.NoError(t, file.Close())
|
||||
|
||||
|
||||
+12
-10
@@ -74,17 +74,17 @@ func (_c *SDK_Algo_Call) RunAndReturn(run func(context.Context, *os.File, *os.Fi
|
||||
return _c
|
||||
}
|
||||
|
||||
// Attestation provides a mock function with given fields: ctx, reportData, attestationFile
|
||||
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error {
|
||||
ret := _m.Called(ctx, reportData, attestationFile)
|
||||
// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType, attestationFile
|
||||
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File) error {
|
||||
ret := _m.Called(ctx, reportData, nonce, attType, attestationFile)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Attestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok {
|
||||
r0 = rf(ctx, reportData, attestationFile)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, int, *os.File) error); ok {
|
||||
r0 = rf(ctx, reportData, nonce, attType, attestationFile)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@@ -100,14 +100,16 @@ type SDK_Attestation_Call struct {
|
||||
// Attestation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - reportData [64]byte
|
||||
// - nonce [32]byte
|
||||
// - attType int
|
||||
// - attestationFile *os.File
|
||||
func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, attestationFile interface{}) *SDK_Attestation_Call {
|
||||
return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, attestationFile)}
|
||||
func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}, attestationFile interface{}) *SDK_Attestation_Call {
|
||||
return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType, attestationFile)}
|
||||
}
|
||||
|
||||
func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, attestationFile *os.File)) *SDK_Attestation_Call {
|
||||
func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File)) *SDK_Attestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].([64]byte), args[2].(*os.File))
|
||||
run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(int), args[4].(*os.File))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -117,7 +119,7 @@ func (_c *SDK_Attestation_Call) Return(_a0 error) *SDK_Attestation_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, *os.File) error) *SDK_Attestation_Call {
|
||||
func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, int, *os.File) error) *SDK_Attestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user