COCOS-326 - Add vTPM support to CoCoS (#376)
CI / checkproto (push) Has been cancelled
CI / ci (push) Has been cancelled
Rust CI Pipeline / rust-check (push) Has been cancelled

* manager, cli and agent vtpm support

* rebase and changed atls for vtpm

* deleted unused code

* changed chekproto.yaml script so it find the manager proto file correctly

* fixe manager proto version

* fix agent tests

* fix server agent test

* fix attestation test

* fix attestation test gofumpt

* created dummy RWC for TPM

* fix comment

* add default PCR values

* rebase main

* fix rust ci and missing header

* changed embedded  attestation to VMPL 2

* fix unused impot

* fix pkg test

* address attestation type

* fix agent attestation test

* add prc15 check

* fix comments

* fix cli tests

* add doc

* add mock for LeveledQuoteProvider when SEV-SNP device is not found

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix manager reading attestation policy

* refactor PCR value checks and update attestation policy values

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests for sev and grpc

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
Co-authored-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Danko Miladinovic
2025-03-07 16:36:47 +01:00
committed by GitHub
parent fa26573643
commit 67f939fc66
57 changed files with 1289 additions and 626 deletions
+54 -35
View File
@@ -3,8 +3,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.0
// protoc v5.29.0
// protoc-gen-go v1.36.4
// protoc v5.29.3
// source: agent/agent.proto
package agent
@@ -14,6 +14,7 @@ import (
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -281,7 +282,9 @@ func (x *ResultResponse) GetFile() []byte {
type AttestationRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64.
TeeNonce []byte `protobuf:"bytes,1,opt,name=teeNonce,proto3" json:"teeNonce,omitempty"` // Should be less or equal 64 bytes.
VtpmNonce []byte `protobuf:"bytes,2,opt,name=vtpmNonce,proto3" json:"vtpmNonce,omitempty"` // Should be less or equal 32 bytes.
Type int32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -316,13 +319,27 @@ func (*AttestationRequest) Descriptor() ([]byte, []int) {
return file_agent_agent_proto_rawDescGZIP(), []int{6}
}
func (x *AttestationRequest) GetReportData() []byte {
func (x *AttestationRequest) GetTeeNonce() []byte {
if x != nil {
return x.ReportData
return x.TeeNonce
}
return nil
}
func (x *AttestationRequest) GetVtpmNonce() []byte {
if x != nil {
return x.VtpmNonce
}
return nil
}
func (x *AttestationRequest) GetType() int32 {
if x != nil {
return x.Type
}
return 0
}
type AttestationResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
@@ -369,7 +386,7 @@ func (x *AttestationResponse) GetFile() []byte {
var File_agent_agent_proto protoreflect.FileDescriptor
var file_agent_agent_proto_rawDesc = []byte{
var file_agent_agent_proto_rawDesc = string([]byte{
0x0a, 0x11, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x0b, 0x41, 0x6c,
0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6c, 0x67,
@@ -386,40 +403,43 @@ var file_agent_agent_proto_rawDesc = []byte{
0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x35, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73,
0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1f, 0x0a,
0x0b, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01,
0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29,
0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67,
0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c,
0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41,
0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12,
0x33, 0x0a, 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e,
0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67,
0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x28, 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14,
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73,
0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12,
0x48, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19,
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61,
0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x62, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73,
0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a,
0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x76, 0x74, 0x70,
0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x76, 0x74,
0x70, 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18,
0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x41,
0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, 0x65, 0x6e, 0x74,
0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, 0x67, 0x6f, 0x12,
0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x33, 0x0a, 0x04,
0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74,
0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74,
0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28,
0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, 0x2e, 0x61, 0x67,
0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x48, 0x0a, 0x0b,
0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, 0x67,
0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41,
0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
})
var (
file_agent_agent_proto_rawDescOnce sync.Once
file_agent_agent_proto_rawDescData = file_agent_agent_proto_rawDesc
file_agent_agent_proto_rawDescData []byte
)
func file_agent_agent_proto_rawDescGZIP() []byte {
file_agent_agent_proto_rawDescOnce.Do(func() {
file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_agent_proto_rawDescData)
file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc)))
})
return file_agent_agent_proto_rawDescData
}
@@ -460,7 +480,7 @@ func file_agent_agent_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_agent_agent_proto_rawDesc,
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc)),
NumEnums: 0,
NumMessages: 8,
NumExtensions: 0,
@@ -471,7 +491,6 @@ func file_agent_agent_proto_init() {
MessageInfos: file_agent_agent_proto_msgTypes,
}.Build()
File_agent_agent_proto = out.File
file_agent_agent_proto_rawDesc = nil
file_agent_agent_proto_goTypes = nil
file_agent_agent_proto_depIdxs = nil
}
+3 -1
View File
@@ -36,7 +36,9 @@ message ResultResponse {
}
message AttestationRequest {
bytes report_data = 1; // Should be of length 64.
bytes teeNonce = 1; // Should be less or equal 64 bytes.
bytes vtpmNonce = 2; // Should be less or equal 32 bytes.
int32 type = 3;
}
message AttestationResponse {
+1 -1
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.0
// - protoc v5.29.3
// source: agent/agent.proto
package agent
+2 -1
View File
@@ -7,6 +7,7 @@ import (
"github.com/go-kit/kit/endpoint"
"github.com/ultravioletrs/cocos/agent"
config "github.com/ultravioletrs/cocos/pkg/attestation"
)
func algoEndpoint(svc agent.Service) endpoint.Endpoint {
@@ -70,7 +71,7 @@ func attestationEndpoint(svc agent.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return attestationRes{}, err
}
file, err := svc.Attestation(ctx, req.ReportData)
file, err := svc.Attestation(ctx, req.TeeNonce, req.VtpmNonce, config.AttestationType(req.AttType))
if err != nil {
return attestationRes{}, err
}
+5 -4
View File
@@ -9,6 +9,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"golang.org/x/crypto/sha3"
)
@@ -141,11 +142,11 @@ func TestAttestationEndpoint(t *testing.T) {
}{
{
name: "Success",
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP},
},
{
name: "Service Error",
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP},
expectedErr: true,
},
}
@@ -153,9 +154,9 @@ func TestAttestationEndpoint(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == svcErr {
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once()
svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, errors.New("")).Once()
} else {
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once()
svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, nil).Once()
}
endpoint := attestationEndpoint(svc)
res, err := endpoint(context.Background(), tt.req)
+13 -2
View File
@@ -4,6 +4,10 @@ package grpc
import (
"errors"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
type algoReq struct {
@@ -38,9 +42,16 @@ func (req resultReq) validate() error {
}
type attestationReq struct {
ReportData [64]byte
TeeNonce [quoteprovider.Nonce]byte
VtpmNonce [vtpm.Nonce]byte
AttType config.AttestationType
}
func (req attestationReq) validate() error {
return nil
switch req.AttType {
case config.SNP, config.VTPM, config.SNPvTPM:
return nil
default:
return errors.New("invalid attestation type in attestation request")
}
}
+21 -3
View File
@@ -11,6 +11,9 @@ import (
"github.com/go-kit/kit/transport/grpc"
"github.com/ultravioletrs/cocos/agent"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@@ -21,6 +24,11 @@ const (
FileSizeKey = "file-size"
)
var (
ErrTEENonceLength = errors.New("malformed report data, expect less or equal to 64 bytes")
ErrVTpmNonceLength = errors.New("malformed vTPM nonce, expect less or equal to 32 bytes")
)
var _ agent.AgentServiceServer = (*grpcServer)(nil)
type grpcServer struct {
@@ -96,10 +104,20 @@ func encodeResultResponse(_ context.Context, response interface{}) (interface{},
func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.AttestationRequest)
if len(req.ReportData) != agent.ReportDataSize {
return nil, errors.New("malformed report data, expect 64 bytes")
var reportData [quoteprovider.Nonce]byte
var nonce [vtpm.Nonce]byte
if len(req.TeeNonce) > quoteprovider.Nonce {
return nil, ErrTEENonceLength
}
return attestationReq{ReportData: [agent.ReportDataSize]byte(req.ReportData)}, nil
if len(req.VtpmNonce) > vtpm.Nonce {
return nil, ErrVTpmNonceLength
}
copy(reportData[:], req.TeeNonce)
copy(nonce[:], req.VtpmNonce)
return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: config.AttestationType(req.Type)}, nil
}
func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) {
+11 -6
View File
@@ -11,6 +11,9 @@ import (
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
@@ -151,10 +154,12 @@ func TestAttestation(t *testing.T) {
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
reportData := [agent.ReportDataSize]byte{}
mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil)
reportData := [quoteprovider.Nonce]byte{}
vtpmNonce := [vtpm.Nonce]byte{}
attestationType := config.SNP
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
err := server.Attestation(&agent.AttestationRequest{ReportData: reportData[:]}, mockStream)
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream)
assert.NoError(t, err)
mockService.AssertExpectations(t)
@@ -199,11 +204,11 @@ func TestEncodeResultResponse(t *testing.T) {
}
func TestDecodeAttestationRequest(t *testing.T) {
reportData := [agent.ReportDataSize]byte{}
req := &agent.AttestationRequest{ReportData: reportData[:]}
nonce := [quoteprovider.Nonce]byte{}
req := &agent.AttestationRequest{TeeNonce: nonce[:]}
decoded, err := decodeAttestationRequest(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, attestationReq{ReportData: reportData}, decoded)
assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded)
}
func TestEncodeAttestationResponse(t *testing.T) {
+5 -2
View File
@@ -13,6 +13,9 @@ import (
"time"
"github.com/ultravioletrs/cocos/agent"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
var _ agent.Service = (*loggingMiddleware)(nil)
@@ -103,7 +106,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e
return lm.svc.Result(ctx)
}
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) (response []byte, err error) {
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) (response []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin))
if err != nil {
@@ -113,5 +116,5 @@ func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent.
lm.logger.Info(fmt.Sprintf("%s without errors", message))
}(time.Now())
return lm.svc.Attestation(ctx, reportData)
return lm.svc.Attestation(ctx, reportData, nonce, attType)
}
+5 -2
View File
@@ -12,6 +12,9 @@ import (
"github.com/go-kit/kit/metrics"
"github.com/ultravioletrs/cocos/agent"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
var _ agent.Service = (*metricsMiddleware)(nil)
@@ -89,11 +92,11 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) {
return ms.svc.Result(ctx)
}
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) ([]byte, error) {
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) {
defer func(begin time.Time) {
ms.counter.With("method", "attestation").Add(1)
ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.Attestation(ctx, reportData)
return ms.svc.Attestation(ctx, reportData, nonce, attType)
}
+1 -1
View File
@@ -71,7 +71,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
return err
}
qp, err := quoteprovider.GetQuoteProvider()
qp, err := quoteprovider.GetLeveledQuoteProvider()
if err != nil {
as.logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error()))
return err
+19 -16
View File
@@ -6,9 +6,10 @@
package mocks
import (
context "context"
agent "github.com/ultravioletrs/cocos/agent"
config "github.com/ultravioletrs/cocos/pkg/attestation"
context "context"
mock "github.com/stretchr/testify/mock"
)
@@ -73,9 +74,9 @@ func (_c *Service_Algo_Call) RunAndReturn(run func(context.Context, agent.Algori
return _c
}
// Attestation provides a mock function with given fields: ctx, reportData
func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) {
ret := _m.Called(ctx, reportData)
// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType
func (_m *Service) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType) ([]byte, error) {
ret := _m.Called(ctx, reportData, nonce, attType)
if len(ret) == 0 {
panic("no return value specified for Attestation")
@@ -83,19 +84,19 @@ func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte
var r0 []byte
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok {
return rf(ctx, reportData)
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)); ok {
return rf(ctx, reportData, nonce, attType)
}
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok {
r0 = rf(ctx, reportData)
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) []byte); ok {
r0 = rf(ctx, reportData, nonce, attType)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok {
r1 = rf(ctx, reportData)
if rf, ok := ret.Get(1).(func(context.Context, [64]byte, [32]byte, config.AttestationType) error); ok {
r1 = rf(ctx, reportData, nonce, attType)
} else {
r1 = ret.Error(1)
}
@@ -111,13 +112,15 @@ type Service_Attestation_Call struct {
// Attestation is a helper method to define mock.On call
// - ctx context.Context
// - reportData [64]byte
func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}) *Service_Attestation_Call {
return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData)}
// - nonce [32]byte
// - attType config.AttestationType
func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}) *Service_Attestation_Call {
return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType)}
}
func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte)) *Service_Attestation_Call {
func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType)) *Service_Attestation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([64]byte))
run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(config.AttestationType))
})
return _c
}
@@ -127,7 +130,7 @@ func (_c *Service_Attestation_Call) Return(_a0 []byte, _a1 error) *Service_Attes
return _c
}
func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte) ([]byte, error)) *Service_Attestation_Call {
func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)) *Service_Attestation_Call {
_c.Call.Return(run)
return _c
}
+41 -20
View File
@@ -23,6 +23,9 @@ import (
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
config "github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
)
@@ -69,8 +72,6 @@ const (
)
const (
// ReportDataSize is the size of the report data expected by the attestation service.
ReportDataSize = 64
algoFilePermission = 0o700
)
@@ -99,6 +100,8 @@ var (
ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")
// ErrAttestationFailed attestation failed.
ErrAttestationFailed = errors.New("failed to get raw quote")
// ErrAttType indicates that the attestation type that is requested does not exist or is not supported.
ErrAttestationType = errors.New("attestation type does not exist or is not supported")
)
// Service specifies an API that must be fullfiled by the domain service
@@ -109,28 +112,29 @@ type Service interface {
Algo(ctx context.Context, algorithm Algorithm) error
Data(ctx context.Context, dataset Dataset) error
Result(ctx context.Context) ([]byte, error)
Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error)
Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error)
State() string
}
type agentService struct {
mu sync.Mutex
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
result []byte // Stores the result of the computation.
sm statemachine.StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
eventSvc events.Service // Service for publishing events related to computation.
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
logger *slog.Logger // Logger for the agent service.
resultsConsumed bool // Indicates if the results have been consumed.
cancel context.CancelFunc // Cancels the computation context.
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
result []byte // Stores the result of the computation.
sm statemachine.StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
eventSvc events.Service // Service for publishing events related to computation.
quoteProvider client.LeveledQuoteProvider // Provider for generating attestation quotes.
logger *slog.Logger // Logger for the agent service.
resultsConsumed bool // Indicates if the results have been consumed.
cancel context.CancelFunc // Cancels the computation context.
vmpl int // VMPL at which the Agent is running.
}
var _ Service = (*agentService)(nil)
// New instantiates the agent service implementation.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.QuoteProvider) Service {
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.LeveledQuoteProvider, vmlp int) Service {
sm := statemachine.NewStateMachine(Idle)
ctx, cancel := context.WithCancel(ctx)
svc := &agentService{
@@ -139,6 +143,7 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quot
quoteProvider: quoteProvider,
logger: logger,
cancel: cancel,
vmpl: vmlp,
}
transitions := []statemachine.Transition{
@@ -397,13 +402,29 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
return as.result, as.runError
}
func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error) {
rawQuote, err := as.quoteProvider.GetRawQuote(reportData)
if err != nil {
return []byte{}, err
func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) {
switch attType {
case config.SNP:
rawQuote, err := as.quoteProvider.GetRawQuoteAtLevel(reportData, uint(as.vmpl))
if err != nil {
return []byte{}, err
}
return rawQuote, nil
case config.VTPM:
vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], false)
if err != nil {
return []byte{}, err
}
return vTPMQuote, nil
case config.SNPvTPM:
vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], true)
if err != nil {
return []byte{}, err
}
return vTPMQuote, nil
default:
return []byte{}, ErrAttestationType
}
return rawQuote, nil
}
func (as *agentService) runComputation(state statemachine.State) {
+17 -13
View File
@@ -22,6 +22,7 @@ import (
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
)
@@ -35,7 +36,7 @@ var (
const datasetFile = "iris.csv"
func TestAlgo(t *testing.T) {
qp, err := quoteprovider.GetQuoteProvider()
qp, err := quoteprovider.GetLeveledQuoteProvider()
require.NoError(t, err)
algo, err := os.ReadFile(algoPath)
@@ -120,7 +121,7 @@ func TestAlgo(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
svc := New(ctx, mglog.NewMock(), events, qp)
svc := New(ctx, mglog.NewMock(), events, qp, 0)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
@@ -139,7 +140,7 @@ func TestAlgo(t *testing.T) {
}
func TestData(t *testing.T) {
qp, err := quoteprovider.GetQuoteProvider()
qp, err := quoteprovider.GetLeveledQuoteProvider()
require.NoError(t, err)
algo, err := os.ReadFile(algoPath)
@@ -215,7 +216,7 @@ func TestData(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
svc := New(ctx, mglog.NewMock(), events, qp)
svc := New(ctx, mglog.NewMock(), events, qp, 0)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
@@ -240,7 +241,7 @@ func TestData(t *testing.T) {
}
func TestResult(t *testing.T) {
qp, err := quoteprovider.GetQuoteProvider()
qp, err := quoteprovider.GetLeveledQuoteProvider()
require.NoError(t, err)
cases := []struct {
@@ -323,23 +324,26 @@ func TestResult(t *testing.T) {
}
func TestAttestation(t *testing.T) {
qp := new(mocks2.QuoteProvider)
qp := new(mocks2.LeveledQuoteProvider)
cases := []struct {
name string
reportData [ReportDataSize]byte
reportData [quoteprovider.Nonce]byte
nonce [vtpm.Nonce]byte
rawQuote []uint8
err error
}{
{
name: "Test attestation successful",
reportData: generateReportData(),
nonce: [32]byte{},
rawQuote: make([]uint8, 0),
err: nil,
},
{
name: "Test attestation failed",
reportData: generateReportData(),
nonce: [32]byte{},
rawQuote: nil,
err: ErrAttestationFailed,
},
@@ -355,22 +359,22 @@ func TestAttestation(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err)
getQuote := qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err)
if tc.err != ErrAttestationFailed {
getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil)
getQuote = qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.nonce, nil)
}
defer getQuote.Unset()
svc := New(ctx, mglog.NewMock(), events, qp)
svc := New(ctx, mglog.NewMock(), events, qp, 0)
time.Sleep(300 * time.Millisecond)
_, err := svc.Attestation(ctx, tc.reportData)
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, 0)
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
})
}
}
func generateReportData() [ReportDataSize]byte {
bytes := make([]byte, ReportDataSize)
func generateReportData() [quoteprovider.Nonce]byte {
bytes := make([]byte, quoteprovider.Nonce)
_, err := rand.Read(bytes)
if err != nil {
log.Fatalf("Failed to generate random bytes: %v", err)