diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index e00dd358..bc0b2804 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -33,8 +33,8 @@ jobs: - name: Set up protoc run: | - PROTOC_VERSION=29.0 - PROTOC_GEN_VERSION=v1.36.0 + PROTOC_VERSION=29.3 + PROTOC_GEN_VERSION=v1.36.4 PROTOC_GRPC_VERSION=v1.5.1 # Download and install protoc @@ -55,7 +55,7 @@ jobs: - name: Set up Cocos-AI run: | # Rename .pb.go files to .pb.go.tmp to prevent conflicts - for p in $(ls pkg/manager/*.pb.go); do + for p in $(ls manager/*.pb.go); do mv $p $p.tmp done @@ -67,7 +67,7 @@ jobs: make protoc # Compare generated Go files with the original ones - for p in $(ls pkg/manager/*.pb.go); do + for p in $(ls manager/*.pb.go); do if ! cmp -s $p $p.tmp; then echo "Proto file and generated Go file $p are out of sync!" exit 1 diff --git a/.golangci.yaml b/.golangci.yaml index 6ea9c222..25158461 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -4,6 +4,10 @@ run: issues: max-issues-per-linter: 10 max-same-issues: 10 + exclude-rules: + - linters: + - makezero + text: "with non-zero initialized length" linters-settings: importas: diff --git a/agent/agent.pb.go b/agent/agent.pb.go index 8a090d6b..19daa5b6 100644 --- a/agent/agent.pb.go +++ b/agent/agent.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.0 -// protoc v5.29.0 +// protoc-gen-go v1.36.4 +// protoc v5.29.3 // source: agent/agent.proto package agent @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -281,7 +282,9 @@ func (x *ResultResponse) GetFile() []byte { type AttestationRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64. + TeeNonce []byte `protobuf:"bytes,1,opt,name=teeNonce,proto3" json:"teeNonce,omitempty"` // Should be less or equal 64 bytes. + VtpmNonce []byte `protobuf:"bytes,2,opt,name=vtpmNonce,proto3" json:"vtpmNonce,omitempty"` // Should be less or equal 32 bytes. + Type int32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -316,13 +319,27 @@ func (*AttestationRequest) Descriptor() ([]byte, []int) { return file_agent_agent_proto_rawDescGZIP(), []int{6} } -func (x *AttestationRequest) GetReportData() []byte { +func (x *AttestationRequest) GetTeeNonce() []byte { if x != nil { - return x.ReportData + return x.TeeNonce } return nil } +func (x *AttestationRequest) GetVtpmNonce() []byte { + if x != nil { + return x.VtpmNonce + } + return nil +} + +func (x *AttestationRequest) GetType() int32 { + if x != nil { + return x.Type + } + return 0 +} + type AttestationResponse struct { state protoimpl.MessageState `protogen:"open.v1"` File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"` @@ -369,7 +386,7 @@ func (x *AttestationResponse) GetFile() []byte { var File_agent_agent_proto protoreflect.FileDescriptor -var file_agent_agent_proto_rawDesc = []byte{ +var file_agent_agent_proto_rawDesc = string([]byte{ 0x0a, 0x11, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x0b, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6c, 0x67, @@ -386,40 +403,43 @@ var file_agent_agent_proto_rawDesc = []byte{ 0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x35, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73, - 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1f, 0x0a, - 0x0b, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29, - 0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, - 0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, - 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, - 0x33, 0x0a, 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, - 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x28, 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, - 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, - 0x48, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, - 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, - 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, - 0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} + 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x62, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73, + 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, + 0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x76, 0x74, 0x70, + 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x76, 0x74, + 0x70, 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x41, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, 0x65, 0x6e, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, 0x67, 0x6f, 0x12, + 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x33, 0x0a, 0x04, + 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, + 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, + 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, 0x2e, 0x61, 0x67, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x48, 0x0a, 0x0b, + 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, 0x67, + 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_agent_agent_proto_rawDescOnce sync.Once - file_agent_agent_proto_rawDescData = file_agent_agent_proto_rawDesc + file_agent_agent_proto_rawDescData []byte ) func file_agent_agent_proto_rawDescGZIP() []byte { file_agent_agent_proto_rawDescOnce.Do(func() { - file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_agent_proto_rawDescData) + file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc))) }) return file_agent_agent_proto_rawDescData } @@ -460,7 +480,7 @@ func file_agent_agent_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_agent_agent_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc)), NumEnums: 0, NumMessages: 8, NumExtensions: 0, @@ -471,7 +491,6 @@ func file_agent_agent_proto_init() { MessageInfos: file_agent_agent_proto_msgTypes, }.Build() File_agent_agent_proto = out.File - file_agent_agent_proto_rawDesc = nil file_agent_agent_proto_goTypes = nil file_agent_agent_proto_depIdxs = nil } diff --git a/agent/agent.proto b/agent/agent.proto index f07b8f02..d9877426 100644 --- a/agent/agent.proto +++ b/agent/agent.proto @@ -36,7 +36,9 @@ message ResultResponse { } message AttestationRequest { - bytes report_data = 1; // Should be of length 64. + bytes teeNonce = 1; // Should be less or equal 64 bytes. + bytes vtpmNonce = 2; // Should be less or equal 32 bytes. + int32 type = 3; } message AttestationResponse { diff --git a/agent/agent_grpc.pb.go b/agent/agent_grpc.pb.go index 6459ecec..80f9211a 100644 --- a/agent/agent_grpc.pb.go +++ b/agent/agent_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.0 +// - protoc v5.29.3 // source: agent/agent.proto package agent diff --git a/agent/api/grpc/endpoint.go b/agent/api/grpc/endpoint.go index da41add6..95c43f71 100644 --- a/agent/api/grpc/endpoint.go +++ b/agent/api/grpc/endpoint.go @@ -7,6 +7,7 @@ import ( "github.com/go-kit/kit/endpoint" "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" ) func algoEndpoint(svc agent.Service) endpoint.Endpoint { @@ -70,7 +71,7 @@ func attestationEndpoint(svc agent.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return attestationRes{}, err } - file, err := svc.Attestation(ctx, req.ReportData) + file, err := svc.Attestation(ctx, req.TeeNonce, req.VtpmNonce, config.AttestationType(req.AttType)) if err != nil { return attestationRes{}, err } diff --git a/agent/api/grpc/endpoint_test.go b/agent/api/grpc/endpoint_test.go index 4ab3494d..b7436ffd 100644 --- a/agent/api/grpc/endpoint_test.go +++ b/agent/api/grpc/endpoint_test.go @@ -9,6 +9,7 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/mocks" + config "github.com/ultravioletrs/cocos/pkg/attestation" "golang.org/x/crypto/sha3" ) @@ -141,11 +142,11 @@ func TestAttestationEndpoint(t *testing.T) { }{ { name: "Success", - req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP}, }, { name: "Service Error", - req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP}, expectedErr: true, }, } @@ -153,9 +154,9 @@ func TestAttestationEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.name == svcErr { - svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once() + svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, errors.New("")).Once() } else { - svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once() + svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, nil).Once() } endpoint := attestationEndpoint(svc) res, err := endpoint(context.Background(), tt.req) diff --git a/agent/api/grpc/requests.go b/agent/api/grpc/requests.go index 4d7d2206..1d5e2b5d 100644 --- a/agent/api/grpc/requests.go +++ b/agent/api/grpc/requests.go @@ -4,6 +4,10 @@ package grpc import ( "errors" + + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) type algoReq struct { @@ -38,9 +42,16 @@ func (req resultReq) validate() error { } type attestationReq struct { - ReportData [64]byte + TeeNonce [quoteprovider.Nonce]byte + VtpmNonce [vtpm.Nonce]byte + AttType config.AttestationType } func (req attestationReq) validate() error { - return nil + switch req.AttType { + case config.SNP, config.VTPM, config.SNPvTPM: + return nil + default: + return errors.New("invalid attestation type in attestation request") + } } diff --git a/agent/api/grpc/server.go b/agent/api/grpc/server.go index 2ac9a387..3bf19a7c 100644 --- a/agent/api/grpc/server.go +++ b/agent/api/grpc/server.go @@ -11,6 +11,9 @@ import ( "github.com/go-kit/kit/transport/grpc" "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -21,6 +24,11 @@ const ( FileSizeKey = "file-size" ) +var ( + ErrTEENonceLength = errors.New("malformed report data, expect less or equal to 64 bytes") + ErrVTpmNonceLength = errors.New("malformed vTPM nonce, expect less or equal to 32 bytes") +) + var _ agent.AgentServiceServer = (*grpcServer)(nil) type grpcServer struct { @@ -96,10 +104,20 @@ func encodeResultResponse(_ context.Context, response interface{}) (interface{}, func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { req := grpcReq.(*agent.AttestationRequest) - if len(req.ReportData) != agent.ReportDataSize { - return nil, errors.New("malformed report data, expect 64 bytes") + var reportData [quoteprovider.Nonce]byte + var nonce [vtpm.Nonce]byte + + if len(req.TeeNonce) > quoteprovider.Nonce { + return nil, ErrTEENonceLength } - return attestationReq{ReportData: [agent.ReportDataSize]byte(req.ReportData)}, nil + + if len(req.VtpmNonce) > vtpm.Nonce { + return nil, ErrVTpmNonceLength + } + + copy(reportData[:], req.TeeNonce) + copy(nonce[:], req.VtpmNonce) + return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: config.AttestationType(req.Type)}, nil } func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) { diff --git a/agent/api/grpc/server_test.go b/agent/api/grpc/server_test.go index 8f07db99..5521fd15 100644 --- a/agent/api/grpc/server_test.go +++ b/agent/api/grpc/server_test.go @@ -11,6 +11,9 @@ import ( "github.com/stretchr/testify/mock" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/mocks" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -151,10 +154,12 @@ func TestAttestation(t *testing.T) { mockStream := &MockAgentService_AttestationServer{ctx: context.Background()} mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil) - reportData := [agent.ReportDataSize]byte{} - mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil) + reportData := [quoteprovider.Nonce]byte{} + vtpmNonce := [vtpm.Nonce]byte{} + attestationType := config.SNP + mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return([]byte("attestation data"), nil) - err := server.Attestation(&agent.AttestationRequest{ReportData: reportData[:]}, mockStream) + err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream) assert.NoError(t, err) mockService.AssertExpectations(t) @@ -199,11 +204,11 @@ func TestEncodeResultResponse(t *testing.T) { } func TestDecodeAttestationRequest(t *testing.T) { - reportData := [agent.ReportDataSize]byte{} - req := &agent.AttestationRequest{ReportData: reportData[:]} + nonce := [quoteprovider.Nonce]byte{} + req := &agent.AttestationRequest{TeeNonce: nonce[:]} decoded, err := decodeAttestationRequest(context.Background(), req) assert.NoError(t, err) - assert.Equal(t, attestationReq{ReportData: reportData}, decoded) + assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded) } func TestEncodeAttestationResponse(t *testing.T) { diff --git a/agent/api/logging.go b/agent/api/logging.go index 60f65d5f..f9435be9 100644 --- a/agent/api/logging.go +++ b/agent/api/logging.go @@ -13,6 +13,9 @@ import ( "time" "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) var _ agent.Service = (*loggingMiddleware)(nil) @@ -103,7 +106,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e return lm.svc.Result(ctx) } -func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) (response []byte, err error) { +func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) (response []byte, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin)) if err != nil { @@ -113,5 +116,5 @@ func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent. lm.logger.Info(fmt.Sprintf("%s without errors", message)) }(time.Now()) - return lm.svc.Attestation(ctx, reportData) + return lm.svc.Attestation(ctx, reportData, nonce, attType) } diff --git a/agent/api/metrics.go b/agent/api/metrics.go index de4fdb91..8b9403ab 100644 --- a/agent/api/metrics.go +++ b/agent/api/metrics.go @@ -12,6 +12,9 @@ import ( "github.com/go-kit/kit/metrics" "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) var _ agent.Service = (*metricsMiddleware)(nil) @@ -89,11 +92,11 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) { return ms.svc.Result(ctx) } -func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) ([]byte, error) { +func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) { defer func(begin time.Time) { ms.counter.With("method", "attestation").Add(1) ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.Attestation(ctx, reportData) + return ms.svc.Attestation(ctx, reportData, nonce, attType) } diff --git a/agent/cvms/server/cvm.go b/agent/cvms/server/cvm.go index fadca8f7..fa4d33e9 100644 --- a/agent/cvms/server/cvm.go +++ b/agent/cvms/server/cvm.go @@ -71,7 +71,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error return err } - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() if err != nil { as.logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) return err diff --git a/agent/mocks/agent.go b/agent/mocks/agent.go index fc717fdd..945f03af 100644 --- a/agent/mocks/agent.go +++ b/agent/mocks/agent.go @@ -6,9 +6,10 @@ package mocks import ( - context "context" - agent "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" + + context "context" mock "github.com/stretchr/testify/mock" ) @@ -73,9 +74,9 @@ func (_c *Service_Algo_Call) RunAndReturn(run func(context.Context, agent.Algori return _c } -// Attestation provides a mock function with given fields: ctx, reportData -func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { - ret := _m.Called(ctx, reportData) +// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType +func (_m *Service) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType) ([]byte, error) { + ret := _m.Called(ctx, reportData, nonce, attType) if len(ret) == 0 { panic("no return value specified for Attestation") @@ -83,19 +84,19 @@ func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte var r0 []byte var r1 error - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { - return rf(ctx, reportData) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)); ok { + return rf(ctx, reportData, nonce, attType) } - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok { - r0 = rf(ctx, reportData) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) []byte); ok { + r0 = rf(ctx, reportData, nonce, attType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { - r1 = rf(ctx, reportData) + if rf, ok := ret.Get(1).(func(context.Context, [64]byte, [32]byte, config.AttestationType) error); ok { + r1 = rf(ctx, reportData, nonce, attType) } else { r1 = ret.Error(1) } @@ -111,13 +112,15 @@ type Service_Attestation_Call struct { // Attestation is a helper method to define mock.On call // - ctx context.Context // - reportData [64]byte -func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}) *Service_Attestation_Call { - return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData)} +// - nonce [32]byte +// - attType config.AttestationType +func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}) *Service_Attestation_Call { + return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType)} } -func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte)) *Service_Attestation_Call { +func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType)) *Service_Attestation_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([64]byte)) + run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(config.AttestationType)) }) return _c } @@ -127,7 +130,7 @@ func (_c *Service_Attestation_Call) Return(_a0 []byte, _a1 error) *Service_Attes return _c } -func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte) ([]byte, error)) *Service_Attestation_Call { +func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)) *Service_Attestation_Call { _c.Call.Return(run) return _c } diff --git a/agent/service.go b/agent/service.go index 40871b0f..15d460c2 100644 --- a/agent/service.go +++ b/agent/service.go @@ -23,6 +23,9 @@ import ( "github.com/ultravioletrs/cocos/agent/events" "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" ) @@ -69,8 +72,6 @@ const ( ) const ( - // ReportDataSize is the size of the report data expected by the attestation service. - ReportDataSize = 64 algoFilePermission = 0o700 ) @@ -99,6 +100,8 @@ var ( ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers") // ErrAttestationFailed attestation failed. ErrAttestationFailed = errors.New("failed to get raw quote") + // ErrAttType indicates that the attestation type that is requested does not exist or is not supported. + ErrAttestationType = errors.New("attestation type does not exist or is not supported") ) // Service specifies an API that must be fullfiled by the domain service @@ -109,28 +112,29 @@ type Service interface { Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error Result(ctx context.Context) ([]byte, error) - Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error) + Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) State() string } type agentService struct { mu sync.Mutex - computation Computation // Holds the current computation request details. - algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. - result []byte // Stores the result of the computation. - sm statemachine.StateMachine // Manages the state transitions of the agent service. - runError error // Stores any error encountered during the computation run. - eventSvc events.Service // Service for publishing events related to computation. - quoteProvider client.QuoteProvider // Provider for generating attestation quotes. - logger *slog.Logger // Logger for the agent service. - resultsConsumed bool // Indicates if the results have been consumed. - cancel context.CancelFunc // Cancels the computation context. + computation Computation // Holds the current computation request details. + algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + result []byte // Stores the result of the computation. + sm statemachine.StateMachine // Manages the state transitions of the agent service. + runError error // Stores any error encountered during the computation run. + eventSvc events.Service // Service for publishing events related to computation. + quoteProvider client.LeveledQuoteProvider // Provider for generating attestation quotes. + logger *slog.Logger // Logger for the agent service. + resultsConsumed bool // Indicates if the results have been consumed. + cancel context.CancelFunc // Cancels the computation context. + vmpl int // VMPL at which the Agent is running. } var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. -func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.QuoteProvider) Service { +func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.LeveledQuoteProvider, vmlp int) Service { sm := statemachine.NewStateMachine(Idle) ctx, cancel := context.WithCancel(ctx) svc := &agentService{ @@ -139,6 +143,7 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quot quoteProvider: quoteProvider, logger: logger, cancel: cancel, + vmpl: vmlp, } transitions := []statemachine.Transition{ @@ -397,13 +402,29 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { return as.result, as.runError } -func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error) { - rawQuote, err := as.quoteProvider.GetRawQuote(reportData) - if err != nil { - return []byte{}, err +func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) { + switch attType { + case config.SNP: + rawQuote, err := as.quoteProvider.GetRawQuoteAtLevel(reportData, uint(as.vmpl)) + if err != nil { + return []byte{}, err + } + return rawQuote, nil + case config.VTPM: + vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], false) + if err != nil { + return []byte{}, err + } + return vTPMQuote, nil + case config.SNPvTPM: + vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], true) + if err != nil { + return []byte{}, err + } + return vTPMQuote, nil + default: + return []byte{}, ErrAttestationType } - - return rawQuote, nil } func (as *agentService) runComputation(state statemachine.State) { diff --git a/agent/service_test.go b/agent/service_test.go index 48e21a88..e521d8c3 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -22,6 +22,7 @@ import ( smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" ) @@ -35,7 +36,7 @@ var ( const datasetFile = "iris.csv" func TestAlgo(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) algo, err := os.ReadFile(algoPath) @@ -120,7 +121,7 @@ func TestAlgo(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - svc := New(ctx, mglog.NewMock(), events, qp) + svc := New(ctx, mglog.NewMock(), events, qp, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) @@ -139,7 +140,7 @@ func TestAlgo(t *testing.T) { } func TestData(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) algo, err := os.ReadFile(algoPath) @@ -215,7 +216,7 @@ func TestData(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - svc := New(ctx, mglog.NewMock(), events, qp) + svc := New(ctx, mglog.NewMock(), events, qp, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) @@ -240,7 +241,7 @@ func TestData(t *testing.T) { } func TestResult(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) cases := []struct { @@ -323,23 +324,26 @@ func TestResult(t *testing.T) { } func TestAttestation(t *testing.T) { - qp := new(mocks2.QuoteProvider) + qp := new(mocks2.LeveledQuoteProvider) cases := []struct { name string - reportData [ReportDataSize]byte + reportData [quoteprovider.Nonce]byte + nonce [vtpm.Nonce]byte rawQuote []uint8 err error }{ { name: "Test attestation successful", reportData: generateReportData(), + nonce: [32]byte{}, rawQuote: make([]uint8, 0), err: nil, }, { name: "Test attestation failed", reportData: generateReportData(), + nonce: [32]byte{}, rawQuote: nil, err: ErrAttestationFailed, }, @@ -355,22 +359,22 @@ func TestAttestation(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err) + getQuote := qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err) if tc.err != ErrAttestationFailed { - getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil) + getQuote = qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.nonce, nil) } defer getQuote.Unset() - svc := New(ctx, mglog.NewMock(), events, qp) + svc := New(ctx, mglog.NewMock(), events, qp, 0) time.Sleep(300 * time.Millisecond) - _, err := svc.Attestation(ctx, tc.reportData) + _, err := svc.Attestation(ctx, tc.reportData, tc.nonce, 0) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) }) } } -func generateReportData() [ReportDataSize]byte { - bytes := make([]byte, ReportDataSize) +func generateReportData() [quoteprovider.Nonce]byte { + bytes := make([]byte, quoteprovider.Nonce) _, err := rand.Read(bytes) if err != nil { log.Fatalf("Failed to generate random bytes: %v", err) diff --git a/attestation.bin b/attestation.bin index 364fa960..5b858195 100644 Binary files a/attestation.bin and b/attestation.bin differ diff --git a/cli/attestation.go b/cli/attestation.go index 38762919..36232fbb 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -25,8 +25,9 @@ import ( "github.com/google/go-tpm/legacy/tpm2" "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/ultravioletrs/cocos/agent" + config "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" @@ -109,33 +110,37 @@ const ( } } ` + SNP = "snp" + VTPM = "vtpm" + SNPvTPM = "snp-vtpm" ) var ( - mode string - cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} - cfgString string - timeout time.Duration - maxRetryDelay time.Duration - platformInfo string - stepping string - trustedAuthorKeys []string - trustedAuthorHashes []string - trustedIdKeys []string - trustedIdKeyHashes []string - attestationFile string - tpmAttestationFile string - attestation []byte - empty16 = [size16]byte{} - empty32 = [size32]byte{} - empty64 = [size64]byte{} - defaultReportIdMa = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} - getJsonAttestation bool - errReportSize = errors.New("attestation contents too small") - output string - nonce []byte - format string - teeNonce []byte + mode string + cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} + cfgString string + timeout time.Duration + maxRetryDelay time.Duration + platformInfo string + stepping string + trustedAuthorKeys []string + trustedAuthorHashes []string + trustedIdKeys []string + trustedIdKeyHashes []string + attestationFile string + tpmAttestationFile string + attestation []byte + empty16 = [size16]byte{} + empty32 = [size32]byte{} + empty64 = [size64]byte{} + defaultReportIdMa = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} + errReportSize = errors.New("attestation contents too small") + ErrBadAttestation = errors.New("attestation file is corrupted or in wrong format") + output string + nonce []byte + format string + teeNonce []byte + getTextProtoAttestation bool ) var errEmptyFile = errors.New("input file is empty") @@ -178,31 +183,75 @@ func (cli *CLI) NewAttestationCmd() *cobra.Command { func (cli *CLI) NewGetAttestationCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "get", - Short: "Retrieve attestation information from agent. Report data expected in hex enoded string of length 64 bytes.", - Example: "get ", - Args: cobra.ExactArgs(1), + Use: "get", + Short: "Retrieve attestation information from agent. The argument of the command must be the type of the report (snp or vtpm or snp-vtpm).", + ValidArgs: []cobra.Completion{SNP, VTPM, SNPvTPM}, + Example: fmt.Sprintf(`Based on attestation report type: + get %s --tee <512 bit hex value> + get %s --vtpm <256 bit hex value> + get %s --tee <512 bit hex value> --vtpm <256 bit hex value>`, SNP, VTPM, SNPvTPM), + Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } - cmd.Println("Getting attestation") - - reportData, err := hex.DecodeString(args[0]) - if err != nil { - printError(cmd, "Error decoding report data: %v ❌ ", err) + if err := cobra.OnlyValidArgs(cmd, args); err != nil { + printError(cmd, "Bad attestation type: %v ❌ ", err) return } - if len(reportData) != agent.ReportDataSize { - msg := color.New(color.FgRed).Sprintf("report data must be a hex encoded string of length %d bytes ❌ ", agent.ReportDataSize) + + attestationType := args[0] + + attType := config.SNP + switch attestationType { + case SNP: + cmd.Println("Fetching SEV-SNP attestation report") + case VTPM: + cmd.Println("Fetching vTPM report") + attType = config.VTPM + case SNPvTPM: + cmd.Println("Fetching SEV-SNP and vTPM report") + attType = config.SNPvTPM + } + + if (attType == config.VTPM || attType == config.SNPvTPM) && len(nonce) == 0 { + msg := color.New(color.FgRed).Sprint("vTPM nonce must be defined for vTPM attestation ❌ ") cmd.Println(msg) return } + if (attType == config.SNP || attType == config.SNPvTPM) && len(teeNonce) == 0 { + msg := color.New(color.FgRed).Sprint("TEE nonce must be defined for SEV-SNP attestation ❌ ") + cmd.Println(msg) + return + } + + var fixedReportData [quoteprovider.Nonce]byte + if attType != config.VTPM { + if len(teeNonce) > quoteprovider.Nonce { + msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", quoteprovider.Nonce) + cmd.Println(msg) + return + } + + copy(fixedReportData[:], teeNonce) + } + + var fixedVtpmNonceByte [vtpm.Nonce]byte + if attType != config.SNP { + if len(nonce) > vtpm.Nonce { + msg := color.New(color.FgRed).Sprintf("vTPM nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.Nonce) + cmd.Println(msg) + return + } + + copy(fixedVtpmNonceByte[:], nonce) + } + filename := attestationFilePath - if getJsonAttestation { + if getTextProtoAttestation { filename = attestationJson } @@ -212,7 +261,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil { + if err := cli.agentSDK.Attestation(cmd.Context(), fixedReportData, fixedVtpmNonceByte, int(attType), attestationFile); err != nil { printError(cmd, "Failed to get attestation due to error: %v ❌ ", err) return } @@ -222,16 +271,32 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - if getJsonAttestation { + if getTextProtoAttestation { result, err := os.ReadFile(filename) if err != nil { printError(cmd, "Error reading attestation file: %v ❌ ", err) return } - result, err = attesationToJSON(result) + switch attestationType { + case SNP: + result, err = attesationToJSON(result) + case VTPM, SNPvTPM: + marshalOptions := prototext.MarshalOptions{ + Multiline: true, + EmitASCII: true, + } + var attvTPM tpmAttest.Attestation + err = proto.Unmarshal(result, &attvTPM) + if err != nil { + printError(cmd, "failed to unmarshal the attestation report: %v ❌ ", ErrBadAttestation) + } + + result = []byte(marshalOptions.Format(&attvTPM)) + } + if err != nil { - printError(cmd, "Error converting attestation to json: %v ❌ ", err) + printError(cmd, "Error converting attestation to textproto: %v ❌ ", err) return } @@ -245,7 +310,9 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { }, } - cmd.Flags().BoolVarP(&getJsonAttestation, "json", "j", false, "Get attestation in json format") + cmd.Flags().BoolVarP(&getTextProtoAttestation, "textproto", "p", false, "Get attestation in textproto format") + cmd.Flags().BytesHexVarP(&teeNonce, "tee", "e", []byte{}, "Define the nonce for the SNP attestation report (must be used with attestation type snp and snp-vtpm)") + cmd.Flags().BytesHexVarP(&nonce, "vtpm", "t", []byte{}, "Define the nonce for the vTPM attestation report (must be used with attestation type vtpm and snp-vtpm)") return cmd } @@ -585,7 +652,12 @@ func sevsnpverify(cmd *cobra.Command, args []string) error { return fmt.Errorf("error validating input: %v ❌ ", err) } - if err := quoteprovider.VerifyAndValidate(attestation, &cfg); err != nil { + attestationPB, err := abi.ReportCertsToProto(attestation) + if err != nil { + return fmt.Errorf("failed to convert attestation bytes to struct %v ❌ ", err) + } + + if err := quoteprovider.VerifyAndValidate(attestationPB, &cfg); err != nil { return fmt.Errorf("attestation validation and verification failed with error: %v ❌ ", err) } cmd.Println("Attestation validation and verification is successful!") diff --git a/cli/attestation_test.go b/cli/attestation_test.go index e8542568..9678efa6 100644 --- a/cli/attestation_test.go +++ b/cli/attestation_test.go @@ -18,7 +18,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk/mocks" ) @@ -35,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) { cmd.SetOutput(&buf) - reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize) - mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil) + reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce) + mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil) cmd.SetArgs([]string{hex.EncodeToString(reportData)}) err := cmd.Execute() @@ -47,6 +48,10 @@ func TestNewAttestationCmd(t *testing.T) { func TestNewGetAttestationCmd(t *testing.T) { validattestation, err := os.ReadFile("../attestation.bin") require.NoError(t, err) + + teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)) + vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) + testCases := []struct { name string args []string @@ -56,57 +61,85 @@ func TestNewGetAttestationCmd(t *testing.T) { expectedOut string }{ { - name: "successful attestation retrieval", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + name: "successful SNP attestation retrieval", + args: []string{"snp", "--tee", teeNonce}, mockResponse: []byte("mock attestation"), mockError: nil, expectedOut: "Attestation result retrieved and saved successfully!", }, { - name: "invalid report data (decoding error)", - args: []string{"invalid"}, - mockResponse: nil, - mockError: errors.New("error"), - expectedErr: "Error decoding report data", + name: "successful vTPM attestation retrieval", + args: []string{"vtpm", "--vtpm", vtpmNonce}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "Attestation result retrieved and saved successfully!", + }, + { + name: "successful SNP-vTPM attestation retrieval", + args: []string{"snp-vtpm", "--tee", teeNonce, "--vtpm", vtpmNonce}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "Attestation result retrieved and saved successfully!", + }, + { + name: "missing vTPM nonce", + args: []string{"snp-vtpm", "--tee", teeNonce}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "vTPM nonce must be defined for vTPM attestation", + }, + { + name: "missing TEE nonce", + args: []string{"snp-vtpm", "--vtpm", vtpmNonce}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "TEE nonce must be defined for SEV-SNP attestation", }, { name: "invalid report data size", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, 32))}, + args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 65))}, mockResponse: nil, mockError: errors.New("error"), - expectedErr: "report data must be a hex encoded string of length 64 bytes", + expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes", }, { - name: "invalid report data hex", + name: "invalid vTPM data size", + args: []string{"vtpm", "-t", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))}, + mockResponse: nil, + mockError: errors.New("error"), + expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes", + }, + { + name: "invalid arguments", args: []string{"invalid"}, mockResponse: nil, mockError: errors.New("error"), - expectedErr: "Error decoding report data", + expectedErr: "Bad attestation type: invalid argument ", }, { name: "failed to get attestation", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + args: []string{"snp", "-e", teeNonce}, mockResponse: nil, mockError: errors.New("error"), expectedErr: "Failed to get attestation due to error", }, { - name: "JSON report error", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), "--json"}, + name: "Textproto report error", + args: []string{"snp", "-e", teeNonce, "--textproto"}, mockResponse: []byte("mock attestation"), mockError: nil, - expectedErr: "Error converting attestation to json", + expectedErr: "Error converting attestation to textproto", }, { - name: "successful JSON report", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), "--json"}, + name: "successful Textproto report", + args: []string{"snp", "-e", teeNonce, "--textproto"}, mockResponse: validattestation, mockError: nil, expectedOut: "Attestation result retrieved and saved successfully!", }, { name: "connection error", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + args: []string{"snp", "-e", teeNonce}, mockResponse: nil, mockError: errors.New("failed to connect to agent"), expectedErr: "Failed to connect to agent", @@ -128,8 +161,8 @@ func TestNewGetAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOutput(&buf) - mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { - _, err := args.Get(2).(*os.File).Write(tc.mockResponse) + mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { + _, err := args.Get(4).(*os.File).Write(tc.mockResponse) require.NoError(t, err) }) diff --git a/cli/cache.go b/cli/cache.go index dfe5dfd0..3a174f62 100644 --- a/cli/cache.go +++ b/cli/cache.go @@ -12,7 +12,7 @@ import ( "github.com/google/go-sev-guest/proto/check" "github.com/google/go-sev-guest/verify/trust" "github.com/spf13/cobra" - "github.com/ultravioletrs/cocos/pkg/clients/grpc" + config "github.com/ultravioletrs/cocos/pkg/attestation" ) const ( @@ -27,14 +27,14 @@ func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command { Example: "ca-bundle ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - attestationConfiguration := check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} - err := grpc.ReadAttestationPolicy(args[0], &attestationConfiguration) + attestationConfiguration := config.Config{SnpCheck: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &config.PcrConfig{}} + err := config.ReadAttestationPolicy(args[0], &attestationConfiguration) if err != nil { printError(cmd, "Error while reading manifest: %v ❌ ", err) return } - product := attestationConfiguration.RootOfTrust.ProductLine + product := attestationConfiguration.SnpCheck.RootOfTrust.ProductLine getter := trust.DefaultHTTPSGetter() caURL := kds.ProductCertChainURL(abi.VcekReportSigner, product) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 419805a7..3eb57a2a 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -4,6 +4,7 @@ package main import ( "context" + "errors" "fmt" "log" "log/slog" @@ -16,6 +17,7 @@ import ( "github.com/absmach/magistrala/pkg/prometheus" "github.com/caarlos0/env/v11" "github.com/google/go-sev-guest/client" + "github.com/stretchr/testify/mock" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/api" "github.com/ultravioletrs/cocos/agent/cvms" @@ -24,6 +26,7 @@ import ( "github.com/ultravioletrs/cocos/agent/events" agentlogger "github.com/ultravioletrs/cocos/internal/logger" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks" pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm" "golang.org/x/sync/errgroup" @@ -39,6 +42,7 @@ const ( type config struct { LogLevel string `env:"AGENT_LOG_LEVEL" envDefault:"debug"` + Vmpl int `env:"AGENT_VMPL" envDefault:"2"` } func main() { @@ -72,11 +76,20 @@ func main() { return } - qp, err := quoteprovider.GetQuoteProvider() - if err != nil { - logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) - exitCode = 1 - return + var qp client.LeveledQuoteProvider + + if !sevGuesDeviceExists() { + logger.Info("SEV-SNP device not found") + qpMock := new(mocks.LeveledQuoteProvider) + qpMock.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return([]uint8{}, errors.New("SEV-SNP device not found")) + qp = qpMock + } else { + qp, err = quoteprovider.GetLeveledQuoteProvider() + if err != nil { + logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) + exitCode = 1 + return + } } cvmGrpcConfig := pkggrpc.CVMClientConfig{} @@ -111,7 +124,13 @@ func main() { return } - svc := newService(ctx, logger, eventSvc, qp) + if cfg.Vmpl < 0 || cfg.Vmpl > 3 { + logger.Error("vmpl level must be in a range [0, 3]") + exitCode = 1 + return + } + + svc := newService(ctx, logger, eventSvc, qp, cfg.Vmpl) if err := os.MkdirAll(storageDir, 0o755); err != nil { logger.Error(fmt.Sprintf("failed to create storage directory: %s", err)) @@ -150,8 +169,8 @@ func main() { } } -func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.QuoteProvider) agent.Service { - svc := agent.New(ctx, logger, eventSvc, qp) +func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.LeveledQuoteProvider, vmpl int) agent.Service { + svc := agent.New(ctx, logger, eventSvc, qp, vmpl) svc = api.LoggingMiddleware(svc, logger) counter, latency := prometheus.MakeMetrics(svcName, "api") @@ -159,3 +178,12 @@ func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Servic return svc } + +func sevGuesDeviceExists() bool { + d, err := client.OpenDevice() + if err != nil { + return false + } + d.Close() + return true +} diff --git a/cocos-manager.env b/cocos-manager.env index 8d4b6932..07ccb962 100644 --- a/cocos-manager.env +++ b/cocos-manager.env @@ -49,13 +49,13 @@ MANAGER_QEMU_BIN_PATH=qemu-system-x86_64 MANAGER_QEMU_USE_SUDO=true MANAGER_QEMU_ENABLE_SEV=false MANAGER_QEMU_ENABLE_SEV_SNP=false +MANAGER_QEMU_IGVM_FILE=/etc/cocos/coconut-qemu.igvm MANAGER_QEMU_ENABLE_KVM=true MANAGER_QEMU_MACHINE=q35 MANAGER_QEMU_CPU=EPYC MANAGER_QEMU_SMP_COUNT=4 MANAGER_QEMU_SMP_MAXCPUS=16 MANAGER_QEMU_MEM_ID=ram1 -MANAGER_QEMU_KERNEL_HASH=false MANAGER_QEMU_NO_GRAPHIC=true MANAGER_QEMU_MONITOR=pty MANAGER_QEMU_HOST_FWD_RANGE=6100-6200 diff --git a/go.mod b/go.mod index 172f9e1b..ab4cde91 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/google/go-sev-guest v0.12.1 - github.com/google/go-tdx-guest v0.3.1 // indirect + github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843 // indirect github.com/mdlayher/vsock v1.2.1 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 @@ -33,8 +33,11 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gofrs/uuid/v5 v5.3.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/google/certificate-transparency-go v1.1.2 // indirect - github.com/google/go-attestation v0.5.0 // indirect + github.com/google/gce-tcb-verifier v0.2.3-0.20240905212129-12f728a62786 // indirect + github.com/google/go-attestation v0.5.1 // indirect + github.com/google/go-eventlog v0.0.2-0.20241003021507-01bb555f7cba // indirect github.com/google/go-tspi v0.3.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -49,6 +52,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.32.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0 // indirect go.opentelemetry.io/otel/sdk v1.32.0 // indirect + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect gotest.tools/v3 v3.5.1 // indirect ) @@ -61,7 +65,7 @@ require ( github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/google/go-configfs-tsm v0.2.2 // indirect + github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc // indirect github.com/google/go-tpm v0.9.3 github.com/google/go-tpm-tools v0.4.4 github.com/google/logger v1.1.1 @@ -89,3 +93,5 @@ require ( ) replace github.com/virtee/sev-snp-measure-go => github.com/sammyoina/sev-snp-measure-go v0.0.0-20241202151803-ef189f0ff825 + +replace github.com/google/go-tpm-tools => github.com/danko-miladinovic/go-tpm-tools v0.0.0-20250228160324-1ebcfd79567c diff --git a/go.sum b/go.sum index 34c4e973..647c6ede 100644 --- a/go.sum +++ b/go.sum @@ -188,6 +188,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6N github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/danko-miladinovic/go-tpm-tools v0.0.0-20250228160324-1ebcfd79567c h1:gFo8kqRXFoM6ttqMrK+M3xffxco+Yj80kUo3NoMe8LU= +github.com/danko-miladinovic/go-tpm-tools v0.0.0-20250228160324-1ebcfd79567c/go.mod h1:ktjTNq8yZFD6TzdBFefUfen96rF3NpYwpSb2d8bc+Y8= github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -333,8 +335,10 @@ github.com/google/certificate-transparency-go v1.1.2-0.20210422104406-9f33727a7a github.com/google/certificate-transparency-go v1.1.2-0.20210512142713-bed466244fa6/go.mod h1:aF2dp7Dh81mY8Y/zpzyXps4fQW5zQbDu2CxfpJB6NkI= github.com/google/certificate-transparency-go v1.1.2 h1:4hE0GEId6NAW28dFpC+LrRGwQX5dtmXQGDbg8+/MZOM= github.com/google/certificate-transparency-go v1.1.2/go.mod h1:3OL+HKDqHPUfdKrHVQxO6T8nDLO0HF7LRTlkIWXaWvQ= -github.com/google/go-attestation v0.5.0 h1:jXtAWT2sw2Yu8mYU0BC7FDidR+ngxFPSE+pl6IUu3/0= -github.com/google/go-attestation v0.5.0/go.mod h1:0Tik9y3rzV649Jcr7evbljQHQAsIlJucyqQjYDBqktU= +github.com/google/gce-tcb-verifier v0.2.3-0.20240905212129-12f728a62786 h1:1ijRI0+jsZCl3CqeJG3Cib6w+wYCBlD/rWRo5a+ZME4= +github.com/google/gce-tcb-verifier v0.2.3-0.20240905212129-12f728a62786/go.mod h1:Jvv9i6JF1t7sDVW09zP2x+9vN3lcujtih2Zb/lVXaLs= +github.com/google/go-attestation v0.5.1 h1:jqtOrLk5MNdliTKjPbIPrAaRKJaKW+0LIU2n/brJYms= +github.com/google/go-attestation v0.5.1/go.mod h1:KqGatdUhg5kPFkokyzSBDxwSCFyRgIgtRkMp6c3lOBQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -349,8 +353,10 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-configfs-tsm v0.2.2 h1:YnJ9rXIOj5BYD7/0DNnzs8AOp7UcvjfTvt215EWcs98= -github.com/google/go-configfs-tsm v0.2.2/go.mod h1:EL1GTDFMb5PZQWDviGfZV9n87WeGTR/JUg13RfwkgRo= +github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc h1:SG12DWUUM5igxm+//YX5Yq4vhdoRnOG9HkCodkOn+YU= +github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc/go.mod h1:EL1GTDFMb5PZQWDviGfZV9n87WeGTR/JUg13RfwkgRo= +github.com/google/go-eventlog v0.0.2-0.20241003021507-01bb555f7cba h1:05m5+kgZjxYUZrx3bZfkKHl6wkch+Khao6N21rFHInk= +github.com/google/go-eventlog v0.0.2-0.20241003021507-01bb555f7cba/go.mod h1:7huE5P8w2NTObSwSJjboHmB7ioBNblkijdzoVa2skfQ= github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM= github.com/google/go-licenses v0.0.0-20210329231322-ce1d9163b77d/go.mod h1:+TYOmkVoJOpwnS0wfdsJCV9CoD5nJYsHoFk/0CrTK4M= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= @@ -358,12 +364,10 @@ github.com/google/go-replayers/grpcreplay v0.1.0/go.mod h1:8Ig2Idjpr6gifRd6pNVgg github.com/google/go-replayers/httpreplay v0.1.0/go.mod h1:YKZViNhiGgqdBlUbI2MwGpq4pXxNmhJLPHQ7cv2b5no= github.com/google/go-sev-guest v0.12.1 h1:H4rFYnPIn8HtqEsNTmh56Zxcf9BI9n48ZSYCnpYLYvc= github.com/google/go-sev-guest v0.12.1/go.mod h1:SK9vW+uyfuzYdVN0m8BShL3OQCtXZe/JPF7ZkpD3760= -github.com/google/go-tdx-guest v0.3.1 h1:gl0KvjdsD4RrJzyLefDOvFOUH3NAJri/3qvaL5m83Iw= -github.com/google/go-tdx-guest v0.3.1/go.mod h1:/rc3d7rnPykOPuY8U9saMyEps0PZDThLk/RygXm04nE= +github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843 h1:+MoPobRN9HrDhGyn6HnF5NYo4uMBKaiFqAtf/D/OB4A= +github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843/go.mod h1:g/n8sKITIT9xRivBUbizo34DTsUm2nN2uU3A662h09g= github.com/google/go-tpm v0.9.3 h1:+yx0/anQuGzi+ssRqeD6WpXjW2L/V0dItUayO0i9sRc= github.com/google/go-tpm v0.9.3/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= -github.com/google/go-tpm-tools v0.4.4 h1:oiQfAIkc6xTy9Fl5NKTeTJkBTlXdHsxAofmQyxBKY98= -github.com/google/go-tpm-tools v0.4.4/go.mod h1:T8jXkp2s+eltnCDIsXR84/MTcVU9Ja7bh3Mit0pa4AY= github.com/google/go-tspi v0.3.0 h1:ADtq8RKfP+jrTyIWIZDIYcKOMecRqNJFOew2IT0Inus= github.com/google/go-tspi v0.3.0/go.mod h1:xfMGI3G0PhxCdNVcYr1C4C+EizojDg/TXuX5by8CiHI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -898,6 +902,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/hal/linux/configs/cocos_defconfig b/hal/linux/configs/cocos_defconfig index 42dd193e..0162c941 100644 --- a/hal/linux/configs/cocos_defconfig +++ b/hal/linux/configs/cocos_defconfig @@ -27,7 +27,7 @@ BR2_ROOTFS_POST_SCRIPT_ARGS="$(BR2_DEFCONFIG)" # Linux headers same as kernel BR2_PACKAGE_HOST_LINUX_HEADERS_CUSTOM_6_11=y BR2_TOOLCHAIN_HEADERS_LATEST=y -BR2_TOOLCHAIN_HEADERS_AT_LEAST="6.12-rc6" +BR2_TOOLCHAIN_HEADERS_AT_LEAST="6.11-rc7" # Kernel BR2_LINUX_KERNEL=y diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index e7009b7e..bff7ac49 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -25,6 +25,7 @@ import ( "github.com/ultravioletrs/cocos/agent/auth" "github.com/ultravioletrs/cocos/internal/server" "github.com/ultravioletrs/cocos/pkg/atls" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -51,7 +52,7 @@ type Server struct { server.BaseServer server *grpc.Server registerService serviceRegister - quoteProvider client.QuoteProvider + quoteProvider client.LeveledQuoteProvider authSvc auth.Authenticator health *health.Server } @@ -60,7 +61,7 @@ type serviceRegister func(srv *grpc.Server) var _ server.Server = (*Server)(nil) -func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { +func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.LeveledQuoteProvider, authSvc auth.Authenticator) server.Server { base := config.GetBaseConfig() listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port) return &Server{ @@ -301,5 +302,19 @@ func generateCertificatesForATLS() ([]byte, []byte, error) { Bytes: privateKeyBytes, }) + cert, err := x509.ParseCertificate(certDERBytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key to DER format: %w", err) + } + + if err := vtpm.ExtendPCR(vtpm.PCR15, pubKeyDER); err != nil { + return nil, nil, fmt.Errorf("failed to extend vTPM PCR with public key: %w", err) + } + return certBytes, keyBytes, nil } diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index b5bcbaeb..3934f039 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -22,6 +22,7 @@ import ( authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks" "github.com/ultravioletrs/cocos/internal/server" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" ) @@ -30,6 +31,28 @@ const bufSize = 1024 * 1024 var lis *bufconn.Listener +type DummyRWC struct{} + +// Read fills p with byte(len(p)) and returns len(p). +func (l *DummyRWC) Read(p []byte) (int, error) { + n := len(p) + // Fill each byte in p with the value of n as a byte. + for i := range p { + p[i] = byte(n) + } + return n, nil +} + +// Write simply returns len(p) indicating that all bytes were written. +func (l *DummyRWC) Write(p []byte) (int, error) { + // In this simple implementation, we ignore the data. + return len(p), nil +} + +func (l *DummyRWC) Close() error { + return nil +} + func init() { lis = bufconn.Listen(bufSize) } @@ -47,7 +70,7 @@ func TestNew(t *testing.T) { }, } logger := slog.Default() - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -97,7 +120,7 @@ func TestServerStartWithTLSFile(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -144,7 +167,7 @@ func TestServerStartWithmTLSFile(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -184,7 +207,7 @@ func TestServerStop(t *testing.T) { } buf := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -259,6 +282,8 @@ func (b *ThreadSafeBuffer) String() string { } func TestServerInitializationAndStartup(t *testing.T) { + vtpm.ExternalTPM = &DummyRWC{} + testCases := []struct { name string config server.AgentConfig @@ -374,7 +399,7 @@ func TestServerInitializationAndStartup(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc) diff --git a/manager/README.md b/manager/README.md index 5f089064..f37fe3a5 100644 --- a/manager/README.md +++ b/manager/README.md @@ -44,7 +44,10 @@ The service is configured using the environment variables from the following tab | MANAGER_QEMU_SEV_ID | The ID for the Secure Encrypted Virtualization (SEV) device. | sev0 | | MANAGER_QEMU_SEV_CBITPOS | The position of the C-bit in the physical address. | 51 | | MANAGER_QEMU_SEV_REDUCED_PHYS_BITS | The number of reduced physical address bits for SEV. | 1 | +| MANAGER_QEMU_ENABLE_HOST_DATA | Enable additional data for the SEV host. | false | | MANAGER_QEMU_HOST_DATA | Additional data for the SEV host. | | +| MANAGER_QEMU_IGVM_ID | The ID of the IGVM file. | igvm0 | +| MANAGER_QEMU_IGVM_FILE | The file path to the IGVM file. | /root/coconut-qemu.igvm | | MANAGER_QEMU_VSOCK_ID | The ID for the virtual socket device. | vhost-vsock-pci0 | | MANAGER_QEMU_VSOCK_GUEST_CID | The guest-side CID (Context ID) for the virtual socket device. | 3 | | MANAGER_QEMU_VSOCK_VNC | Whether to enable the virtual socket device for VNC. | 0 | @@ -58,7 +61,6 @@ The service is configured using the environment variables from the following tab | MANAGER_QEMU_SMP_COUNT | The number of virtual CPUs. | 4 | | MANAGER_QEMU_SMP_MAXCPUS | The maximum number of virtual CPUs. | 64 | | MANAGER_QEMU_MEM_ID | The ID for the memory device. | ram1 | -| MANAGER_QEMU_KERNEL_HASH | Whether to enable kernel hash verification. | false | | MANAGER_QEMU_NO_GRAPHIC | Whether to disable the graphical display. | true | | MANAGER_QEMU_MONITOR | The type of monitor to use. | pty | | MANAGER_QEMU_HOST_FWD_RANGE | The range of host ports to forward. | 6100-6200 | @@ -232,21 +234,7 @@ MANAGER_QEMU_ENABLE_SEV=false \ MANAGER_QEMU_ENABLE_SEV_SNP=true \ MANAGER_QEMU_SEV_CBITPOS=51 \ MANAGER_QEMU_BIN_PATH= \ -MANAGER_QEMU_QEMU_OVMF_CODE_FILE= \ -./build/cocos-manager -``` - -To include the kernel hash into the measurement of the attestation report (SEV or SEV-SNP), start manager like this - -```sh -MANAGER_GRPC_URL=localhost:7001 \ -MANAGER_LOG_LEVEL=debug \ -MANAGER_QEMU_ENABLE_SEV=false \ -MANAGER_QEMU_ENABLE_SEV_SNP=true \ -MANAGER_QEMU_SEV_CBITPOS=51 \ -MANAGER_QEMU_KERNEL_HASH=true \ -MANAGER_QEMU_BIN_PATH= \ -MANAGER_QEMU_QEMU_OVMF_CODE_FILE= \ +MANAGER_QEMU_IGVM_FILE= \ ./build/cocos-manager ``` diff --git a/manager/attestation_policy.go b/manager/attestation_policy.go index 12d74c1e..ce5fe016 100644 --- a/manager/attestation_policy.go +++ b/manager/attestation_policy.go @@ -76,8 +76,8 @@ func (ms *managerService) FetchAttestationPolicy(_ context.Context, computationI attestationPolicy.Policy.Measurement = measurement } - if vmi.Config.HostData != "" { - hostData, err := base64.StdEncoding.DecodeString(vmi.Config.HostData) + if vmi.Config.SevConfig.EnableHostData { + hostData, err := base64.StdEncoding.DecodeString(vmi.Config.SevConfig.HostData) if err != nil { return nil, err } diff --git a/manager/attestation_policy_test.go b/manager/attestation_policy_test.go index 54f87b03..a5e9230a 100644 --- a/manager/attestation_policy_test.go +++ b/manager/attestation_policy_test.go @@ -57,9 +57,10 @@ func TestFetchAttestationPolicy(t *testing.T) { binaryBehavior: "success", vmConfig: qemu.VMInfo{ Config: qemu.Config{ - EnableSEV: true, - SMPCount: 2, - CPU: "EPYC", + EnableSEV: true, + EnableSEVSNP: false, + SMPCount: 2, + CPU: "EPYC", OVMFCodeConfig: qemu.OVMFCodeConfig{ File: "/path/to/OVMF_CODE.fd", }, @@ -68,23 +69,6 @@ func TestFetchAttestationPolicy(t *testing.T) { }, expectedError: "open /path/to/OVMF_CODE.fd: no such file or directory", }, - { - name: "Valid SEV-SNP configuration", - computationId: "sev-snp-computation", - binaryBehavior: "success", - vmConfig: qemu.VMInfo{ - Config: qemu.Config{ - EnableSEVSNP: true, - SMPCount: 4, - CPU: "EPYC-v2", - OVMFCodeConfig: qemu.OVMFCodeConfig{ - File: "/path/to/OVMF_CODE_SNP.fd", - }, - }, - LaunchTCB: 0, - }, - expectedError: "open /path/to/OVMF_CODE_SNP.fd: no such file or director", - }, { name: "Invalid computation ID", computationId: "non-existent", diff --git a/manager/manager.pb.go b/manager/manager.pb.go index a1a6da7f..34f4aec2 100644 --- a/manager/manager.pb.go +++ b/manager/manager.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.0 -// protoc v5.29.0 +// protoc-gen-go v1.36.4 +// protoc v5.29.3 // source: manager/manager.proto package manager @@ -15,6 +15,7 @@ import ( emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -422,7 +423,7 @@ func (x *SVMInfoReq) GetId() string { var File_manager_manager_proto protoreflect.FileDescriptor -var file_manager_manager_proto_rawDesc = []byte{ +var file_manager_manager_proto_rawDesc = string([]byte{ 0x0a, 0x15, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, @@ -488,16 +489,16 @@ var file_manager_manager_proto_rawDesc = []byte{ 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_manager_manager_proto_rawDescOnce sync.Once - file_manager_manager_proto_rawDescData = file_manager_manager_proto_rawDesc + file_manager_manager_proto_rawDescData []byte ) func file_manager_manager_proto_rawDescGZIP() []byte { file_manager_manager_proto_rawDescOnce.Do(func() { - file_manager_manager_proto_rawDescData = protoimpl.X.CompressGZIP(file_manager_manager_proto_rawDescData) + file_manager_manager_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_manager_manager_proto_rawDesc), len(file_manager_manager_proto_rawDesc))) }) return file_manager_manager_proto_rawDescData } @@ -538,7 +539,7 @@ func file_manager_manager_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_manager_manager_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_manager_manager_proto_rawDesc), len(file_manager_manager_proto_rawDesc)), NumEnums: 0, NumMessages: 7, NumExtensions: 0, @@ -549,7 +550,6 @@ func file_manager_manager_proto_init() { MessageInfos: file_manager_manager_proto_msgTypes, }.Build() File_manager_manager_proto = out.File - file_manager_manager_proto_rawDesc = nil file_manager_manager_proto_goTypes = nil file_manager_manager_proto_depIdxs = nil } diff --git a/manager/manager_grpc.pb.go b/manager/manager_grpc.pb.go index b8111ce4..77a8fd26 100644 --- a/manager/manager_grpc.pb.go +++ b/manager/manager_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.0 +// - protoc v5.29.3 // source: manager/manager.proto package manager diff --git a/manager/qemu/config.go b/manager/qemu/config.go index 86267353..fcc0c3d4 100644 --- a/manager/qemu/config.go +++ b/manager/qemu/config.go @@ -56,7 +56,13 @@ type SevConfig struct { ID string `env:"SEV_ID" envDefault:"sev0"` CBitPos int `env:"SEV_CBITPOS" envDefault:"51"` ReducedPhysBits int `env:"SEV_REDUCED_PHYS_BITS" envDefault:"1"` - HostData string `env:"HOST_DATA" envDefault:""` + EnableHostData bool `env:"ENABLE_HOST_DATA" envDefault:"false"` + HostData string `env:"HOST_DATA" envDefault:""` +} + +type IGVMConfig struct { + ID string `env:"IGVM_ID" envDefault:"igvm0"` + File string `env:"IGVM_FILE" envDefault:"/root/coconut-qemu.igvm"` } type VSockConfig struct { @@ -80,9 +86,6 @@ type Config struct { MemID string `env:"MEM_ID" envDefault:"ram1"` MemoryConfig - // Kernel hash - KernelHash bool `env:"KERNEL_HASH" envDefault:"false"` - // OVMF OVMFCodeConfig OVMFVarsConfig @@ -100,6 +103,9 @@ type Config struct { // SEV SevConfig + // vTPM + IGVMConfig + // display NoGraphic bool `env:"NO_GRAPHIC" envDefault:"true"` Monitor string `env:"MONITOR" envDefault:"pty"` @@ -173,40 +179,39 @@ func (config Config) ConstructQemuArgs() []string { // SEV if config.EnableSEV || config.EnableSEVSNP { sevType := "sev-guest" - kernelHash := "" hostData := "" args = append(args, "-machine", - fmt.Sprintf("confidential-guest-support=%s,memory-backend=%s", + fmt.Sprintf("confidential-guest-support=%s,memory-backend=%s,igvm-cfg=%s", config.SevConfig.ID, - config.MemID)) + config.MemID, + config.IGVMConfig.ID)) if config.EnableSEVSNP { - args = append(args, "-bios", config.OVMFCodeConfig.File) sevType = "sev-snp-guest" - if config.SevConfig.HostData != "" { + if config.SevConfig.EnableHostData { hostData = fmt.Sprintf(",host-data=%s", config.SevConfig.HostData) } } - if config.KernelHash { - kernelHash = ",kernel-hashes=on" - } - args = append(args, "-object", fmt.Sprintf("memory-backend-memfd,id=%s,size=%s,share=true,prealloc=false", config.MemID, config.MemoryConfig.Size)) args = append(args, "-object", - fmt.Sprintf("%s,id=%s,cbitpos=%d,reduced-phys-bits=%d%s%s", + fmt.Sprintf("%s,id=%s,cbitpos=%d,reduced-phys-bits=%d%s", sevType, config.SevConfig.ID, config.SevConfig.CBitPos, config.SevConfig.ReducedPhysBits, - kernelHash, hostData)) + + args = append(args, "-object", + fmt.Sprintf("igvm-cfg,id=%s,file=%s", + config.IGVMConfig.ID, + config.IGVMConfig.File)) } args = append(args, "-kernel", config.DiskImgConfig.KernelFile) diff --git a/manager/qemu/config_test.go b/manager/qemu/config_test.go index 67d4d2af..e21e2124 100644 --- a/manager/qemu/config_test.go +++ b/manager/qemu/config_test.go @@ -132,6 +132,10 @@ func TestConstructQemuArgs(t *testing.T) { CBitPos: 51, ReducedPhysBits: 1, }, + IGVMConfig: IGVMConfig{ + ID: "igvm0", + File: "/test/path/cocos-igvm.igvm", + }, NoGraphic: true, Monitor: "pty", }, @@ -144,10 +148,10 @@ func TestConstructQemuArgs(t *testing.T) { "-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002", "-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=", "-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3", - "-machine", "confidential-guest-support=sev0,memory-backend=ram1", - "-bios", "/usr/share/OVMF/OVMF_CODE.fd", + "-machine", "confidential-guest-support=sev0,memory-backend=ram1,igvm-cfg=igvm0", "-object", "memory-backend-memfd,id=ram1,size=2048M,share=true,prealloc=false", "-object", "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1", + "-object", "igvm-cfg,id=igvm0,file=/test/path/cocos-igvm.igvm", "-kernel", "img/bzImage", "-append", "\"quiet console=null\"", "-initrd", "img/rootfs.cpio.gz", @@ -167,37 +171,6 @@ func TestConstructQemuArgs(t *testing.T) { } } -func TestConstructQemuArgs_KernelHash(t *testing.T) { - config := Config{ - EnableSEVSNP: true, - KernelHash: true, - SevConfig: SevConfig{ - ID: "sev0", - CBitPos: 51, - ReducedPhysBits: 1, - }, - } - - result := config.ConstructQemuArgs() - - expected := "-object" - expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,kernel-hashes=on" - - found := false - for i, arg := range result { - if arg == expected && i+1 < len(result) { - if result[i+1] == expectedValue { - found = true - break - } - } - } - - if !found { - t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with kernel hashes enabled") - } -} - func TestConstructQemuArgs_HostData(t *testing.T) { config := Config{ EnableSEVSNP: true, @@ -205,6 +178,7 @@ func TestConstructQemuArgs_HostData(t *testing.T) { ID: "sev0", CBitPos: 51, ReducedPhysBits: 1, + EnableHostData: true, HostData: "test-host-data", }, } diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index c41e48c0..2c55689f 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -59,7 +59,7 @@ func (v *qemuVM) Start() (err error) { v.vmi.Config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.vmi.Config.NetDevConfig.ID, id) v.vmi.Config.SevConfig.ID = fmt.Sprintf("%s-%s", v.vmi.Config.SevConfig.ID, id) - if !v.vmi.Config.KernelHash { + if !v.vmi.Config.EnableSEVSNP { // Copy firmware vars file. srcFile := v.vmi.Config.OVMFVarsConfig.File dstFile := fmt.Sprintf("%s/%s-%s.fd", tmpDir, firmwareVars, id) diff --git a/manager/service.go b/manager/service.go index c5bdcc47..eb1832d7 100644 --- a/manager/service.go +++ b/manager/service.go @@ -20,9 +20,9 @@ import ( "github.com/google/uuid" "github.com/ultravioletrs/cocos/manager/qemu" "github.com/ultravioletrs/cocos/manager/vm" + config "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/manager" "golang.org/x/crypto/sha3" - "google.golang.org/protobuf/encoding/protojson" ) const ( @@ -166,14 +166,14 @@ func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string, return "", id, errors.Wrap(ErrFailedToReadPolicy, err) } - var attestationPolicy check.Config + attestationPolicy := config.Config{SnpCheck: &check.Config{RootOfTrust: &check.RootOfTrust{}, Policy: &check.Policy{}}, PcrConfig: &config.PcrConfig{}} - if err = protojson.Unmarshal(f, &attestationPolicy); err != nil { + if err = config.ReadAttestationPolicyFromByte(f, &attestationPolicy); err != nil { return "", id, errors.Wrap(ErrUnmarshalFailed, err) } // Define the TCB that was present at launch of the VM. - cfg.LaunchTCB = attestationPolicy.Policy.MinimumLaunchTcb + cfg.LaunchTCB = attestationPolicy.SnpCheck.Policy.MinimumLaunchTcb } agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax) diff --git a/mockery.yml b/mockery.yml index 87f0bd08..e9a304f5 100644 --- a/mockery.yml +++ b/mockery.yml @@ -107,7 +107,7 @@ packages: mockname: "{{.InterfaceName}}" github.com/google/go-sev-guest/client: interfaces: - QuoteProvider: + LeveledQuoteProvider: config: dir: "./pkg/attestation/quoteprovider/mocks" filename: "QuoteProvider.go" diff --git a/pkg/atls/atlsListener.go b/pkg/atls/atlsListener.go index 0e059294..1b627919 100644 --- a/pkg/atls/atlsListener.go +++ b/pkg/atls/atlsListener.go @@ -20,8 +20,8 @@ import ( "unsafe" "github.com/absmach/magistrala/pkg/errors" - "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) const ( @@ -51,8 +51,8 @@ var ( errConnCreate = errors.New("could not create connection") ) -type ValidationVerification func(data1, data2 []byte) error -type FetchAttestation func(data1 []byte) ([]byte, error) +type ValidationVerification func(data1, data2, data3, data4 []byte) error +type FetchAttestation func(data1, data2, data3 []byte) ([]byte, error) func registerFetchAttestation(callback FetchAttestation) uintptr { handle := cgo.NewHandle(callback) @@ -70,7 +70,7 @@ func validationVerificationCallback(teeType C.int) uintptr { case NoTee: return uintptr(0) case AmdSevSnp: - return registerValidationVerification(quoteprovider.VerifyAttestationReportTLS) + return registerValidationVerification(vtpm.VTPMVerify) default: return uintptr(0) } @@ -82,22 +82,24 @@ func fetchAttestationCallback(teeType C.int) uintptr { case NoTee: return uintptr(0) case AmdSevSnp: - return registerFetchAttestation(quoteprovider.FetchAttestation) + return registerFetchAttestation(vtpm.FetchATLSQuote) default: return uintptr(0) } } //export callVerificationValidationCallback -func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uchar, attReportSize C.int, repData *C.uchar) C.int { +func callVerificationValidationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, quote *C.uchar, quoteSize C.int, teeNonce *C.uchar, nonce *C.uchar) C.int { handle := cgo.Handle(callbackHandle) defer handle.Delete() callback := handle.Value().(ValidationVerification) - attestationReport := C.GoBytes(unsafe.Pointer(attReport), attReportSize) - reportData := C.GoBytes(unsafe.Pointer(repData), agent.ReportDataSize) + pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen) + attestationReport := C.GoBytes(unsafe.Pointer(quote), quoteSize) + teeData := C.GoBytes(unsafe.Pointer(teeNonce), quoteprovider.Nonce) + nonceData := C.GoBytes(unsafe.Pointer(nonce), vtpm.Nonce) - err := callback(attestationReport, reportData) + err := callback(attestationReport, pubKeyCert, teeData, nonceData) if err != nil { fmt.Fprintf(os.Stderr, "callback failed %v", err) return C.int(-1) @@ -107,20 +109,22 @@ func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uch } //export callFetchAttestationCallback -func callFetchAttestationCallback(callbackHandle uintptr, reportDataByte *C.uchar, outlen *C.int) *C.uchar { +func callFetchAttestationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar, outlen *C.ulong) *C.uchar { handle := cgo.Handle(callbackHandle) defer handle.Delete() callback := handle.Value().(FetchAttestation) - reportData := C.GoBytes(unsafe.Pointer(reportDataByte), agent.ReportDataSize) + pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen) + teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), quoteprovider.Nonce) + vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce) - quote, err := callback(reportData) + quote, err := callback(pubKeyCert, teeNonceData, vTPMNonce) if err != nil { fmt.Fprintf(os.Stderr, "attestation callback returned nil") return nil } - *outlen = C.int(len(quote)) + *outlen = C.ulong(len(quote)) resultC := C.malloc(C.size_t(len(quote))) if resultC == nil { fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback") @@ -232,23 +236,23 @@ func (c *ATLSConn) Read(b []byte) (int, error) { case noError: return n, nil // no error. case errorZeroReturn: - fmt.Fprintf(os.Stdout, "Connection closed by peer") + fmt.Fprintf(os.Stdout, "Connection closed by peer\n") return 0, io.EOF // connection closed. case errorWantRead: - fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later") + fmt.Fprintf(os.Stderr, "Operation read incomplete, retry later\n") return 0, nil // non-fatal, just retry later. case errorWantWrite: - fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later") + fmt.Fprintf(os.Stderr, "Operation write incomplete, retry later\n") return 0, nil // non-fatal, just retry later. case errorSyscall: - fmt.Fprintf(os.Stderr, "I/O error") + fmt.Fprintf(os.Stderr, "I/O error\n") return 0, syscall.ECONNRESET // return connection reset error. case errorSsl: - fmt.Fprintf(os.Stderr, "I/O error") + fmt.Fprintf(os.Stderr, "I/O error\n") return 0, syscall.ECONNRESET // return connection reset error. default: fmt.Fprintf(os.Stderr, "SSL error occurred: %d\n", errCode) - return 0, fmt.Errorf("SSL error") + return 0, fmt.Errorf("SSL error\n") } } @@ -297,7 +301,7 @@ func (c *ATLSConn) Close() error { return errTLSConn } else if int(ret) == 1 { c.tlsConn = nil - break; + break } } diff --git a/pkg/atls/extensions.c b/pkg/atls/extensions.c index 6e27f82d..d6f1908c 100644 --- a/pkg/atls/extensions.c +++ b/pkg/atls/extensions.c @@ -7,30 +7,27 @@ #include #include -extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* attReport, int attReportSize, const u_char* repData); -extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* reportDataByte, int* outlen); +extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* quote, int quoteSize, const u_char* teeNonce, const u_char* nonce); +extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* teeNonceByte, const u_char* vTPMNonceByte, unsigned long* outlen); extern uintptr_t validationVerificationCallback(int teeType); extern uintptr_t fetchAttestationCallback(int teeType); -int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char *attestationReport, int reportSize, u_char *reportData) { - if (attestationReport == NULL || reportData == NULL) { - fprintf(stderr, "attestation data and report data cannot be NULL\n"); +int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char* pub_key, int pub_key_len, u_char *quote, int quote_size, u_char *tee_nonce, u_char *vtpm_nonce) { + if (quote == NULL || vtpm_nonce == NULL || tee_nonce == NULL || pub_key == NULL) { + fprintf(stderr, "attestation and noce and public key cannot be NULL\n"); return -1; } - - return callVerificationValidationCallback(callbackHandle, attestationReport, reportSize, reportData); + return callVerificationValidationCallback(callbackHandle, pub_key, pub_key_len, quote, quote_size, tee_nonce, vtpm_nonce); } -u_char* triggerFetchAttestationCallback(uintptr_t callbackHandle, char *reportData) { - int outlen = REPORT_DATA_SIZE; - - if(reportData == NULL) { +u_char* triggerFetchAttestationCallback(uintptr_t callback_handle, u_char* pub_key, int pub_key_len, char *tee_nonce, char *vtpm_nonce, unsigned long *outlen) { + if(tee_nonce == NULL || vtpm_nonce == NULL) { fprintf(stderr, "Report data cannot be NULL"); return NULL; } - return callFetchAttestationCallback(callbackHandle, reportData, &outlen); + return callFetchAttestationCallback(callback_handle, pub_key, pub_key_len, tee_nonce, vtpm_nonce, outlen); } int check_sev_snp() { @@ -47,46 +44,6 @@ int check_sev_snp() { return 1; } -int compute_sha256_of_public_key_nonce(X509 *cert, u_char *nonce, u_char *hash) { - EVP_PKEY *pkey = NULL; - u_char *pubkey_buf = NULL; - u_char *concatinated = NULL; - int pubkey_len = 0; - int totla_len = 0; - - pkey = X509_get_pubkey(cert); - if (pkey == NULL) { - fprintf(stderr, "Failed to extract public key from certificate\n"); - return 0; - } - - pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); - if (pubkey_len <= 0) { - fprintf(stderr, "Failed to convert public key to DER format\n"); - EVP_PKEY_free(pkey); - return -1; - } - - totla_len = pubkey_len + CLIENT_RANDOM_SIZE; - concatinated = (u_char*)malloc(totla_len); - if (concatinated == NULL) { - perror("failed to allocate memory"); - return -1; - } - memcpy(concatinated, nonce, CLIENT_RANDOM_SIZE); - memcpy(concatinated + CLIENT_RANDOM_SIZE, pubkey_buf, pubkey_len); - - // Compute the SHA-512 hash of the DER-encoded public key and the random nonce - SHA512(concatinated, totla_len, hash); - - // Clean up - EVP_PKEY_free(pkey); - OPENSSL_free(pubkey_buf); - free(concatinated); - - return 0; // Success -} - /* Evidence request extension - Contains a random nonce that goes into the attestation report @@ -121,9 +78,14 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type, } if (ext_data != NULL) { - if (RAND_bytes(ext_data->er.data, CLIENT_RANDOM_SIZE) != 1) { - perror("could not generate random bytes, will use SSL client random"); - SSL_get_client_random(s, ext_data->er.data, CLIENT_RANDOM_SIZE); + if (RAND_bytes(ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE) != 1) { + perror("could not generate random bytes for vtpm nonce, will use SSL client random"); + SSL_get_client_random(s, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE); + } + + if (RAND_bytes(ext_data->er.tee_nonce, REPORT_DATA_SIZE) != 1) { + perror("could not generate random bytes for tee nonce, will use SSL client random"); + SSL_get_client_random(s, ext_data->er.tee_nonce, REPORT_DATA_SIZE); } } else { fprintf(stderr, "add_arg is NULL\n"); @@ -132,7 +94,8 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type, return -1; } - memcpy(er->data, ext_data->er.data, CLIENT_RANDOM_SIZE); + memcpy(er->vtpm_nonce, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE); + memcpy(er->tee_nonce, ext_data->er.tee_nonce, REPORT_DATA_SIZE); er->tee_type = AMD_TEE; ext_data->er.tee_type = AMD_TEE; @@ -201,7 +164,8 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type, evidence_request *er = (evidence_request*)in; if (ext_data != NULL) { - memcpy(ext_data->er.data, er->data, CLIENT_RANDOM_SIZE); + memcpy(ext_data->er.vtpm_nonce, er->vtpm_nonce, CLIENT_RANDOM_SIZE); + memcpy(ext_data->er.tee_nonce, er->tee_nonce, REPORT_DATA_SIZE); ext_data->er.tee_type = er->tee_type; } else { fprintf(stderr, "parse_arg is NULL\n"); @@ -238,7 +202,7 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type, /* Attestation Certificate extension - Contains the attestation report - - The attestation report contains the hash of the nonce and the Public Key of the x.509 Agent certificate + - The attestation report contains the hash of the nonce, the Public Key of the x.509 Agent certificate, and the vTPM AK */ void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type, unsigned int context, @@ -263,40 +227,46 @@ int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type, { tls_extension_data *ext_data = (tls_extension_data*)add_arg; if (ext_data != NULL) { - u_char *attestation_report; - u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char)); - - if (hash == NULL) { - perror("could not allocate memory"); - *al = SSL_AD_INTERNAL_ERROR; - return -1; - } + u_char *quote; + size_t len = 0; + EVP_PKEY *pkey = NULL; + u_char *pubkey_buf = NULL; + int pubkey_len = 0; + if (x != NULL) { - int ret = compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash); - if (ret != 0) { - fprintf(stderr, "error while calculating hash\n"); - free(hash); - *al = SSL_AD_INTERNAL_ERROR; + pkey = X509_get_pubkey(x); + if (pkey == NULL) { + fprintf(stderr, "Failed to extract public key from certificate\n"); + return -1; + } + + pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); + if (pubkey_len <= 0) { + fprintf(stderr, "Failed to convert public key to DER format\n"); + EVP_PKEY_free(pkey); return -1; } } else { fprintf(stderr, "agent certificate must be used for aTLS\n"); - free(hash); *al = SSL_AD_INTERNAL_ERROR; return -1; } - attestation_report = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, hash); - if (attestation_report == NULL) { + quote = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, pubkey_buf, pubkey_len, ext_data->er.tee_nonce, ext_data->er.vtpm_nonce, &len); + if (quote == NULL) { fprintf(stderr, "attestation report is NULL\n"); *al = SSL_AD_INTERNAL_ERROR; + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); return -1; } - free(hash); - *out = attestation_report; - *outlen = ATTESTATION_REPORT_SIZE; + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); + + *out = quote; + *outlen = len; return 1; } else { fprintf(stderr, "add_arg is NULL\n"); @@ -329,34 +299,41 @@ int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type, tls_extension_data *ext_data = (tls_extension_data*)parse_arg; if (ext_data != NULL) { - char *attestation_report = (char*)malloc(ATTESTATION_REPORT_SIZE*sizeof(char)); - u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char)); + char *quote = (char*)malloc(inlen*sizeof(char)); + EVP_PKEY *pkey = NULL; + u_char *pubkey_buf = NULL; + int pubkey_len = 0; int res = 0; - if (hash == NULL || attestation_report == NULL) { + if (quote == NULL) { perror("could not allocate memory"); - - if (hash != NULL) free(hash); - if (attestation_report != NULL) free(attestation_report); - return 0; } - if (compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash) != 0) { - fprintf(stderr, "calculating hash failed\n"); - free(attestation_report); - free(hash); - return 0; + pkey = X509_get_pubkey(x); + if (pkey == NULL) { + fprintf(stderr, "Failed to extract public key from certificate\n"); + return -1; } + + pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); + if (pubkey_len <= 0) { + fprintf(stderr, "Failed to convert public key to DER format\n"); + EVP_PKEY_free(pkey); + return -1; + } + memcpy(quote, in, inlen); - memcpy(attestation_report, in, inlen); - - res = triggerVerificationValidationCallback(ext_data->verification_validation_handler, - attestation_report, - ATTESTATION_REPORT_SIZE, - hash); - free(attestation_report); - free(hash); + res = triggerVerificationValidationCallback(ext_data->verification_validation_handler, + pubkey_buf, + pubkey_len, + quote, + inlen, + (u_char*)&ext_data->er.tee_nonce, + (u_char*)&ext_data->er.vtpm_nonce); + free(quote); + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); if (res != 0) { fprintf(stderr, "verification and validation failed, aborting connection\n"); diff --git a/pkg/atls/extensions.h b/pkg/atls/extensions.h index 5fb0d230..456dbe66 100644 --- a/pkg/atls/extensions.h +++ b/pkg/atls/extensions.h @@ -6,7 +6,6 @@ #define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65 #define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66 -#define ATTESTATION_REPORT_SIZE 0x4A0 #define REPORT_DATA_SIZE 64 #define CLIENT_RANDOM_SIZE 32 #define TLS_CLIENT_CTX 0 @@ -19,7 +18,8 @@ typedef struct evidence_request { int tee_type; - char data[CLIENT_RANDOM_SIZE]; + char vtpm_nonce[CLIENT_RANDOM_SIZE]; + char tee_nonce[REPORT_DATA_SIZE]; } evidence_request; typedef struct tls_extension_data @@ -32,7 +32,7 @@ typedef struct tls_extension_data typedef struct tls_server_connection { int server_fd; - char* cert; + char* cert; int cert_len; char* key; int key_len; diff --git a/pkg/attestation/config.go b/pkg/attestation/config.go new file mode 100644 index 00000000..c8ba29fa --- /dev/null +++ b/pkg/attestation/config.go @@ -0,0 +1,69 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package config + +import ( + "encoding/json" + "os" + + "github.com/absmach/magistrala/pkg/errors" + "github.com/google/go-sev-guest/proto/check" + "google.golang.org/protobuf/encoding/protojson" +) + +type AttestationType int32 + +const ( + SNP AttestationType = iota + VTPM + SNPvTPM +) + +var ( + AttestationPolicy = Config{SnpCheck: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &PcrConfig{}} + ErrAttestationPolicyOpen = errors.New("failed to open Attestation Policy file") + ErrAttestationPolicyDecode = errors.New("failed to decode Attestation Policy file") + ErrAttestationPolicyMissing = errors.New("failed due to missing Attestation Policy file") +) + +type PcrValues struct { + Sha256 map[string]string `json:"sha256"` + Sha384 map[string]string `json:"sha384"` +} + +type PcrConfig struct { + PCRValues PcrValues `json:"pcr_values"` +} + +type Config struct { + SnpCheck *check.Config + PcrConfig *PcrConfig +} + +func ReadAttestationPolicy(policyPath string, attestationConfiguration *Config) error { + if policyPath != "" { + policyData, err := os.ReadFile(policyPath) + if err != nil { + return errors.Wrap(ErrAttestationPolicyOpen, err) + } + + return ReadAttestationPolicyFromByte(policyData, attestationConfiguration) + } + + return ErrAttestationPolicyMissing +} + +func ReadAttestationPolicyFromByte(policyData []byte, attestationConfiguration *Config) error { + unmarshalOptions := protojson.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true} + + if err := unmarshalOptions.Unmarshal(policyData, attestationConfiguration.SnpCheck); err != nil { + return errors.Wrap(ErrAttestationPolicyDecode, err) + } + + if err := json.Unmarshal(policyData, attestationConfiguration.PcrConfig); err != nil { + return errors.Wrap(ErrAttestationPolicyDecode, err) + } + + return nil +} diff --git a/pkg/attestation/quoteprovider/embed.go b/pkg/attestation/quoteprovider/embed.go index c92f797d..cd839b5a 100644 --- a/pkg/attestation/quoteprovider/embed.go +++ b/pkg/attestation/quoteprovider/embed.go @@ -8,26 +8,24 @@ package quoteprovider import ( "github.com/google/go-sev-guest/client" - "github.com/google/go-sev-guest/proto/check" + "github.com/google/go-sev-guest/proto/sevsnp" pb "github.com/google/go-sev-guest/proto/sevsnp" cocosai "github.com/ultravioletrs/cocos" ) -var ( - AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} -) +const Nonce = 64 -var _ client.QuoteProvider = (*embeddedQuoteProvider)(nil) +var _ client.LeveledQuoteProvider = (*embeddedQuoteProvider)(nil) type embeddedQuoteProvider struct { } -func GetQuoteProvider() (client.QuoteProvider, error) { +func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { return &embeddedQuoteProvider{}, nil } -// GetQuote returns the SEV quote for the given report data. -func (e *embeddedQuoteProvider) GetRawQuote(reportData [64]byte) ([]byte, error) { +// GetRawQuoteAtLevel returns the SEV quote for the given report data and VMPL. +func (e *embeddedQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]byte, error) { return cocosai.EmbeddedAttestation, nil } @@ -46,6 +44,6 @@ func FetchAttestation(reportDataSlice []byte) ([]byte, error) { return cocosai.EmbeddedAttestation, nil } -func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error { +func VerifyAttestationReportTLS(attestation *sevsnp.Attestation, reportData []byte) error { return nil } diff --git a/pkg/attestation/quoteprovider/mocks/QuoteProvider.go b/pkg/attestation/quoteprovider/mocks/QuoteProvider.go index 179e2e01..0c636159 100644 --- a/pkg/attestation/quoteprovider/mocks/QuoteProvider.go +++ b/pkg/attestation/quoteprovider/mocks/QuoteProvider.go @@ -10,42 +10,42 @@ import ( mock "github.com/stretchr/testify/mock" ) -// QuoteProvider is an autogenerated mock type for the QuoteProvider type -type QuoteProvider struct { +// LeveledQuoteProvider is an autogenerated mock type for the LeveledQuoteProvider type +type LeveledQuoteProvider struct { mock.Mock } -type QuoteProvider_Expecter struct { +type LeveledQuoteProvider_Expecter struct { mock *mock.Mock } -func (_m *QuoteProvider) EXPECT() *QuoteProvider_Expecter { - return &QuoteProvider_Expecter{mock: &_m.Mock} +func (_m *LeveledQuoteProvider) EXPECT() *LeveledQuoteProvider_Expecter { + return &LeveledQuoteProvider_Expecter{mock: &_m.Mock} } -// GetRawQuote provides a mock function with given fields: reportData -func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) { - ret := _m.Called(reportData) +// GetRawQuoteAtLevel provides a mock function with given fields: reportData, vmpl +func (_m *LeveledQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]uint8, error) { + ret := _m.Called(reportData, vmpl) if len(ret) == 0 { - panic("no return value specified for GetRawQuote") + panic("no return value specified for GetRawQuoteAtLevel") } var r0 []uint8 var r1 error - if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok { - return rf(reportData) + if rf, ok := ret.Get(0).(func([64]byte, uint) ([]uint8, error)); ok { + return rf(reportData, vmpl) } - if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok { - r0 = rf(reportData) + if rf, ok := ret.Get(0).(func([64]byte, uint) []uint8); ok { + r0 = rf(reportData, vmpl) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]uint8) } } - if rf, ok := ret.Get(1).(func([64]byte) error); ok { - r1 = rf(reportData) + if rf, ok := ret.Get(1).(func([64]byte, uint) error); ok { + r1 = rf(reportData, vmpl) } else { r1 = ret.Error(1) } @@ -53,36 +53,37 @@ func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) { return r0, r1 } -// QuoteProvider_GetRawQuote_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuote' -type QuoteProvider_GetRawQuote_Call struct { +// LeveledQuoteProvider_GetRawQuoteAtLevel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuoteAtLevel' +type LeveledQuoteProvider_GetRawQuoteAtLevel_Call struct { *mock.Call } -// GetRawQuote is a helper method to define mock.On call +// GetRawQuoteAtLevel is a helper method to define mock.On call // - reportData [64]byte -func (_e *QuoteProvider_Expecter) GetRawQuote(reportData interface{}) *QuoteProvider_GetRawQuote_Call { - return &QuoteProvider_GetRawQuote_Call{Call: _e.mock.On("GetRawQuote", reportData)} +// - vmpl uint +func (_e *LeveledQuoteProvider_Expecter) GetRawQuoteAtLevel(reportData interface{}, vmpl interface{}) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { + return &LeveledQuoteProvider_GetRawQuoteAtLevel_Call{Call: _e.mock.On("GetRawQuoteAtLevel", reportData, vmpl)} } -func (_c *QuoteProvider_GetRawQuote_Call) Run(run func(reportData [64]byte)) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Run(run func(reportData [64]byte, vmpl uint)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([64]byte)) + run(args[0].([64]byte), args[1].(uint)) }) return _c } -func (_c *QuoteProvider_GetRawQuote_Call) Return(_a0 []uint8, _a1 error) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Return(_a0 []uint8, _a1 error) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *QuoteProvider_GetRawQuote_Call) RunAndReturn(run func([64]byte) ([]uint8, error)) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) RunAndReturn(run func([64]byte, uint) ([]uint8, error)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Return(run) return _c } // IsSupported provides a mock function with given fields: -func (_m *QuoteProvider) IsSupported() bool { +func (_m *LeveledQuoteProvider) IsSupported() bool { ret := _m.Called() if len(ret) == 0 { @@ -99,35 +100,35 @@ func (_m *QuoteProvider) IsSupported() bool { return r0 } -// QuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported' -type QuoteProvider_IsSupported_Call struct { +// LeveledQuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported' +type LeveledQuoteProvider_IsSupported_Call struct { *mock.Call } // IsSupported is a helper method to define mock.On call -func (_e *QuoteProvider_Expecter) IsSupported() *QuoteProvider_IsSupported_Call { - return &QuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")} +func (_e *LeveledQuoteProvider_Expecter) IsSupported() *LeveledQuoteProvider_IsSupported_Call { + return &LeveledQuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")} } -func (_c *QuoteProvider_IsSupported_Call) Run(run func()) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) Run(run func()) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *QuoteProvider_IsSupported_Call) Return(_a0 bool) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) Return(_a0 bool) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Return(_a0) return _c } -func (_c *QuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Return(run) return _c } // Product provides a mock function with given fields: -func (_m *QuoteProvider) Product() *sevsnp.SevProduct { +func (_m *LeveledQuoteProvider) Product() *sevsnp.SevProduct { ret := _m.Called() if len(ret) == 0 { @@ -146,40 +147,40 @@ func (_m *QuoteProvider) Product() *sevsnp.SevProduct { return r0 } -// QuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product' -type QuoteProvider_Product_Call struct { +// LeveledQuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product' +type LeveledQuoteProvider_Product_Call struct { *mock.Call } // Product is a helper method to define mock.On call -func (_e *QuoteProvider_Expecter) Product() *QuoteProvider_Product_Call { - return &QuoteProvider_Product_Call{Call: _e.mock.On("Product")} +func (_e *LeveledQuoteProvider_Expecter) Product() *LeveledQuoteProvider_Product_Call { + return &LeveledQuoteProvider_Product_Call{Call: _e.mock.On("Product")} } -func (_c *QuoteProvider_Product_Call) Run(run func()) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) Run(run func()) *LeveledQuoteProvider_Product_Call { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *QuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call { _c.Call.Return(_a0) return _c } -func (_c *QuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call { _c.Call.Return(run) return _c } -// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewLeveledQuoteProvider creates a new instance of LeveledQuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewQuoteProvider(t interface { +func NewLeveledQuoteProvider(t interface { mock.TestingT Cleanup(func()) -}) *QuoteProvider { - mock := &QuoteProvider{} +}) *LeveledQuoteProvider { + mock := &LeveledQuoteProvider{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/pkg/attestation/quoteprovider/sev.go b/pkg/attestation/quoteprovider/sev.go index fc8bf482..649f5d4e 100644 --- a/pkg/attestation/quoteprovider/sev.go +++ b/pkg/attestation/quoteprovider/sev.go @@ -14,7 +14,6 @@ import ( "time" "github.com/absmach/magistrala/pkg/errors" - "github.com/google/go-sev-guest/abi" "github.com/google/go-sev-guest/client" "github.com/google/go-sev-guest/proto/check" "github.com/google/go-sev-guest/proto/sevsnp" @@ -22,6 +21,7 @@ import ( "github.com/google/go-sev-guest/verify" "github.com/google/go-sev-guest/verify/trust" "github.com/google/logger" + config "github.com/ultravioletrs/cocos/pkg/attestation" "google.golang.org/protobuf/proto" ) @@ -29,20 +29,19 @@ const ( cocosDirectory = ".cocos" caBundleName = "ask_ark.pem" attestationReportSize = 0x4A0 - reportDataSize = 64 + Nonce = 64 sevProductNameMilan = "Milan" sevProductNameGenoa = "Genoa" + sevVMPL = 2 ) var ( - AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} - timeout = time.Minute * 2 - maxTryDelay = time.Second * 30 + timeout = time.Minute * 2 + maxTryDelay = time.Second * 30 ) var ( errProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa)) - errReportSize = errors.New("attestation report size mismatch") errAttVerification = errors.New("attestation verification failed") errAttValidation = errors.New("attestation validation failed") ) @@ -138,38 +137,31 @@ func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error return nil } -func GetQuoteProvider() (client.QuoteProvider, error) { - return client.GetQuoteProvider() +func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { + return client.GetLeveledQuoteProvider() } -func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error { - config, err := copyConfig(&AttConfigurationSEVSNP) +func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte) error { + config, err := copyConfig(config.AttestationPolicy.SnpCheck) if err != nil { return errors.Wrap(fmt.Errorf("failed to create a copy of attestation policy"), err) } + // Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report. + // This data is not part of the attestation report and it will be ignored. + attestationPB.CertificateChain = nil config.Policy.ReportData = reportData[:] - return VerifyAndValidate(attestationBytes, config) + return VerifyAndValidate(attestationPB, config) } -func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error { +func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { logger.Init("", false, false, io.Discard) - if len(attestationReport) < attestationReportSize { - return errReportSize - } - attestationBytes := attestationReport[:attestationReportSize] - - attestationPB, err := abi.ReportCertsToProto(attestationBytes) - if err != nil { - return fmt.Errorf("failed to convert attestation bytes to struct %v", errors.Wrap(errAttVerification, err)) - } - - if err = verifyReport(attestationPB, cfg); err != nil { + if err := verifyReport(attestationPB, cfg); err != nil { return err } - if err = validateReport(attestationPB, cfg); err != nil { + if err := validateReport(attestationPB, cfg); err != nil { return err } @@ -177,19 +169,19 @@ func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error { } func FetchAttestation(reportDataSlice []byte) ([]byte, error) { - var reportData [reportDataSize]byte + var reportData [Nonce]byte - qp, err := GetQuoteProvider() + qp, err := GetLeveledQuoteProvider() if err != nil { return []byte{}, fmt.Errorf("could not get quote provider") } - if len(reportData) > reportDataSize { + if len(reportData) > Nonce { return []byte{}, fmt.Errorf("attestation report size mismatch") } copy(reportData[:], reportDataSlice) - rawQuote, err := qp.GetRawQuote(reportData) + rawQuote, err := qp.GetRawQuoteAtLevel(reportData, sevVMPL) if err != nil { return []byte{}, fmt.Errorf("failed to get raw quote") } diff --git a/pkg/attestation/quoteprovider/sev_test.go b/pkg/attestation/quoteprovider/sev_test.go index 6f2e4250..d3c687c0 100644 --- a/pkg/attestation/quoteprovider/sev_test.go +++ b/pkg/attestation/quoteprovider/sev_test.go @@ -17,14 +17,10 @@ import ( "github.com/google/go-sev-guest/proto/sevsnp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + config "github.com/ultravioletrs/cocos/pkg/attestation" "google.golang.org/protobuf/encoding/protojson" ) -const ( - measurementOffset = 0x90 - signatureOffset = 0x2A0 -) - func TestFillInAttestationLocal(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_home") require.NoError(t, err) @@ -76,18 +72,18 @@ func TestFillInAttestationLocal(t *testing.T) { } func TestVerifyAttestationReportSuccess(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepVerifyAttReport(t) tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte goodProduct int err error }{ { name: "Valid attestation, validation and verification is performed succsessfully", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, goodProduct: 1, err: nil, @@ -103,20 +99,20 @@ func TestVerifyAttestationReportSuccess(t *testing.T) { } func TestVerifyAttestationReportMalformedSignature(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepVerifyAttReport(t) // Change random data so in the signature so the signature failes - file[signatureOffset] = file[signatureOffset] ^ 0x01 + attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01 tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, distorted signature", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errAttVerification, }, @@ -131,17 +127,17 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) { } func TestVerifyAttestationReportUnknownProduct(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepVerifyAttReport(t) tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, unknown product", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errProductLine, }, @@ -149,8 +145,8 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - AttConfigurationSEVSNP.RootOfTrust.ProductLine = "" - AttConfigurationSEVSNP.Policy.Product = nil + config.AttestationPolicy.SnpCheck.RootOfTrust.ProductLine = "" + config.AttestationPolicy.SnpCheck.Policy.Product = nil err := VerifyAttestationReportTLS(tt.attestationReport, tt.reportData) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) @@ -158,20 +154,20 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) { } func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepVerifyAttReport(t) // Change random data in the measurement so the measurement does not match - file[measurementOffset] = file[measurementOffset] ^ 0x01 + attestationPB.Report.Measurement[0] = attestationPB.Report.Measurement[0] ^ 0x01 tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, malformed policy (measurement)", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errAttVerification, }, @@ -185,32 +181,34 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { } } -func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) { +func prepVerifyAttReport(t *testing.T) (*sevsnp.Attestation, []byte) { file, err := os.ReadFile("../../../attestation.bin") require.NoError(t, err) - rr, err := abi.ReportCertsToProto(file) - require.NoError(t, err) - if len(file) < attestationReportSize { file = append(file, make([]byte, attestationReportSize-len(file))...) } - AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} + rr, err := abi.ReportCertsToProto(file) + require.NoError(t, err) + + config.AttestationPolicy = config.Config{SnpCheck: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &config.PcrConfig{}} attestationPolicyFile, err := os.ReadFile("../../../scripts/attestation_policy/attestation_policy.json") require.NoError(t, err) - err = protojson.Unmarshal(attestationPolicyFile, &AttConfigurationSEVSNP) + unmarshalOptions := protojson.UnmarshalOptions{DiscardUnknown: true} + + err = unmarshalOptions.Unmarshal(attestationPolicyFile, config.AttestationPolicy.SnpCheck) require.NoError(t, err) - AttConfigurationSEVSNP.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN} - AttConfigurationSEVSNP.Policy.FamilyId = rr.Report.FamilyId - AttConfigurationSEVSNP.Policy.ImageId = rr.Report.ImageId - AttConfigurationSEVSNP.Policy.Measurement = rr.Report.Measurement - AttConfigurationSEVSNP.Policy.HostData = rr.Report.HostData - AttConfigurationSEVSNP.Policy.ReportIdMa = rr.Report.ReportIdMa - AttConfigurationSEVSNP.RootOfTrust.ProductLine = sevProductNameMilan + config.AttestationPolicy.SnpCheck.Policy.Product = &sevsnp.SevProduct{Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN} + config.AttestationPolicy.SnpCheck.Policy.FamilyId = rr.Report.FamilyId + config.AttestationPolicy.SnpCheck.Policy.ImageId = rr.Report.ImageId + config.AttestationPolicy.SnpCheck.Policy.Measurement = rr.Report.Measurement + config.AttestationPolicy.SnpCheck.Policy.HostData = rr.Report.HostData + config.AttestationPolicy.SnpCheck.Policy.ReportIdMa = rr.Report.ReportIdMa + config.AttestationPolicy.SnpCheck.RootOfTrust.ProductLine = sevProductNameMilan - return file, rr.Report.ReportData + return rr, rr.Report.ReportData } diff --git a/pkg/attestation/vtpm/vtpm.go b/pkg/attestation/vtpm/vtpm.go new file mode 100644 index 00000000..9102f9c3 --- /dev/null +++ b/pkg/attestation/vtpm/vtpm.go @@ -0,0 +1,307 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package vtpm + +import ( + "bytes" + "crypto" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "os" + "strconv" + + "github.com/absmach/magistrala/pkg/errors" + "github.com/google/go-sev-guest/abi" + "github.com/google/go-tpm-tools/client" + "github.com/google/go-tpm-tools/proto/attest" + "github.com/google/go-tpm-tools/proto/tpm" + "github.com/google/go-tpm-tools/server" + "github.com/google/go-tpm/legacy/tpm2" + "github.com/google/go-tpm/tpmutil" + config "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "golang.org/x/crypto/sha3" + "google.golang.org/protobuf/proto" +) + +const ( + eventLog = "/sys/kernel/security/tpm0/binary_bios_measurements" + Nonce = 32 + PCR15 = 15 + Hash256 = 32 + Hash384 = 48 +) + +var ( + ExternalTPM io.ReadWriteCloser + ErrNoHashAlgo = errors.New("hash algo is not supported") +) + +type tpmWrapper struct { + io.ReadWriteCloser +} + +func (et tpmWrapper) EventLog() ([]byte, error) { + return os.ReadFile(eventLog) +} + +func OpenTpm() (io.ReadWriteCloser, error) { + if ExternalTPM != nil { + return tpmWrapper{ExternalTPM}, nil + } + + tw := tpmWrapper{} + var err error + + tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpmrm0") + if os.IsNotExist(err) { + tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpm0") + } + + return tw, err +} + +func ExtendPCR(pcrIndex int, value []byte) error { + rwc, err := OpenTpm() + if err != nil { + return err + } + defer rwc.Close() + + fixedSha256Hash := sha3.Sum256(value) + if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA256, fixedSha256Hash[:], ""); err != nil { + return err + } + + fixedSha384Hash := sha3.Sum384(value) + if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA384, fixedSha384Hash[:], ""); err != nil { + return err + } + + return nil +} + +func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool) ([]byte, error) { + attestation, err := fetchVTPMQuote(vTPMNonce) + if err != nil { + return []byte{}, err + } + + if teeAttestaion { + attestation, err = addTEEAttestation(attestation, teeNonce) + if err != nil { + return []byte{}, err + } + } + + return marshalQuote(attestation) +} + +func FetchATLSQuote(pubKey, teeNonce, vTPMNonce []byte) ([]byte, error) { + attestation, err := fetchVTPMQuote(vTPMNonce) + if err != nil { + return []byte{}, err + } + + reportData, err := createTEEAttestationReportNonce(pubKey, attestation.GetAkPub(), teeNonce) + if err != nil { + return []byte{}, err + } + + attestation, err = addTEEAttestation(attestation, reportData) + if err != nil { + return []byte{}, err + } + + return marshalQuote(attestation) +} + +func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byte) error { + attestation := &attest.Attestation{} + + err := proto.Unmarshal(quote, attestation) + if err != nil { + return fmt.Errorf("fail to unmarshal quote: %v", err) + } + + ak := attestation.GetAkPub() + pub, err := tpm2.DecodePublic(ak) + if err != nil { + return err + } + + cryptoPub, err := pub.Key() + if err != nil { + return err + } + + reportData, err := createTEEAttestationReportNonce(pubKeyTLS, ak, teeNonce) + if err != nil { + return fmt.Errorf("fail to calculate report data: %v", err) + } + + if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), reportData); err != nil { + return fmt.Errorf("failed to verify TEE attestation report: %v", err) + } + + _, err = server.VerifyAttestation(attestation, server.VerifyOpts{Nonce: vtpmNonce, TrustedAKs: []crypto.PublicKey{cryptoPub}}) + if err != nil { + return fmt.Errorf("verifying attestation: %w", err) + } + + s256, s384 := calculatePCRTLSKey(pubKeyTLS) + + if err := checkExpectedPCRValues(attestation, s256, s384); err != nil { + return fmt.Errorf("PCR values do not match expected PCR values: %w", err) + } + + return nil +} + +func publicKeyToBytes(pubKey interface{}) ([]byte, error) { + derBytes, err := x509.MarshalPKIXPublicKey(pubKey) + if err != nil { + return nil, err + } + return derBytes, nil +} + +func createTEEAttestationReportNonce(pubKeyTLS []byte, ak []byte, nonce []byte) ([]byte, error) { + pub, err := tpm2.DecodePublic(ak) + if err != nil { + return []byte{}, err + } + + cryptoPub, err := pub.Key() + if err != nil { + return []byte{}, err + } + + pubKeyBytes, err := publicKeyToBytes(cryptoPub) + if err != nil { + return []byte{}, err + } + + reportData := append(append(pubKeyTLS, pubKeyBytes...), nonce...) + hash := sha3.Sum512(reportData) + + return hash[:], nil +} + +func marshalQuote(attestation *attest.Attestation) ([]byte, error) { + out, err := proto.Marshal(attestation) + if err != nil { + return []byte{}, fmt.Errorf("failed to marshal vTPM attestation report: %v", err) + } + + return out, nil +} + +func fetchVTPMQuote(nonce []byte) (*attest.Attestation, error) { + rwc, err := OpenTpm() + if err != nil { + return nil, err + } + defer rwc.Close() + + attestationKey, err := client.AttestationKeyRSA(rwc) + if err != nil { + return nil, fmt.Errorf("failed to create attestation key: %v", err) + } + defer attestationKey.Close() + + var fixedNonce [Nonce]byte + copy(fixedNonce[:], nonce) + attestOpts := client.AttestOpts{} + attestOpts.Nonce = fixedNonce[:] + + attestOpts.TCGEventLog, err = client.GetEventLog(rwc) + if err != nil { + return nil, fmt.Errorf("failed to retrieve TCG Event Log: %w", err) + } + + attestation, err := attestationKey.Attest(attestOpts) + if err != nil { + return nil, fmt.Errorf("failed to collect attestation report: %v", err) + } + + return attestation, nil +} + +func addTEEAttestation(attestation *attest.Attestation, nonce []byte) (*attest.Attestation, error) { + rawTeeAttestation, err := quoteprovider.FetchAttestation(nonce) + if err != nil { + return attestation, fmt.Errorf("failed to fetch TEE attestation report: %v", err) + } + + extReport, err := abi.ReportCertsToProto(rawTeeAttestation) + if err != nil { + return attestation, fmt.Errorf("failed to export the TEE report: %v", err) + } + attestation.TeeAttestation = &attest.Attestation_SevSnpAttestation{ + SevSnpAttestation: extReport, + } + + return attestation, nil +} + +func checkExpectedPCRValues(attestation *attest.Attestation, ePcr256 []byte, ePcr384 []byte) error { + quotes := attestation.GetQuotes() + for i := range quotes { + quote := quotes[i] + var pcrMap map[string]string + var pcr15 []byte + switch quote.Pcrs.Hash { + case tpm.HashAlgo_SHA256: + pcrMap = config.AttestationPolicy.PcrConfig.PCRValues.Sha256 + pcr15 = ePcr256 + case tpm.HashAlgo_SHA384: + pcrMap = config.AttestationPolicy.PcrConfig.PCRValues.Sha384 + pcr15 = ePcr384 + default: + return errors.Wrap(ErrNoHashAlgo, fmt.Errorf("algo: %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)])) + } + + pcr15Index := uint32(15) + if !bytes.Equal(quote.Pcrs.Pcrs[pcr15Index], pcr15) { + return fmt.Errorf("for algo %s PCR[15] expected %s but found %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)], hex.EncodeToString(pcr15), hex.EncodeToString(quote.Pcrs.Pcrs[pcr15Index])) + } + + for i, v := range pcrMap { + index, err := strconv.ParseInt(i, 10, 32) + if err != nil { + return fmt.Errorf("error converting PCR index to int32: %v\n", err) + } + value, err := hex.DecodeString(v) + if err != nil { + return fmt.Errorf("error converting PCR value to byte: %v\n", err) + } + if !bytes.Equal(quote.Pcrs.Pcrs[uint32(index)], value) { + return fmt.Errorf("for algo %s PCR[%d] expected %s but found %s", tpm.HashAlgo_name[int32(quote.Pcrs.Hash)], index, hex.EncodeToString(value), hex.EncodeToString(quote.Pcrs.Pcrs[uint32(index)])) + } + } + } + return nil +} + +// Return SHA256 and SHA384 values of the input public key. +func calculatePCRTLSKey(pubKey []byte) ([]byte, []byte) { + init256 := make([]byte, Hash256) + init384 := make([]byte, Hash384) + + key256 := sha3.Sum256(pubKey) + key384 := sha3.Sum384(pubKey) + + pcrValue256 := append(init256, key256[:]...) + pcrValue384 := append(init384, key384[:]...) + + newPcr256 := sha256.Sum256(pcrValue256) + newPcr384 := sha512.Sum384(pcrValue384) + + return newPcr256[:], newPcr384[:] +} diff --git a/pkg/clients/grpc/agent/agent_test.go b/pkg/clients/grpc/agent/agent_test.go index 1a97a759..060071ec 100644 --- a/pkg/clients/grpc/agent/agent_test.go +++ b/pkg/clients/grpc/agent/agent_test.go @@ -15,6 +15,7 @@ import ( "github.com/ultravioletrs/cocos/agent" agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc" "github.com/ultravioletrs/cocos/agent/mocks" + config "github.com/ultravioletrs/cocos/pkg/attestation" pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" "google.golang.org/grpc" "google.golang.org/grpc/health" @@ -112,7 +113,7 @@ func TestAgentClientIntegration(t *testing.T) { }, AttestedTLS: true, }, - err: pkggrpc.ErrAttestationPolicyMissing, + err: config.ErrAttestationPolicyMissing, }, } diff --git a/pkg/clients/grpc/atls.go b/pkg/clients/grpc/atls.go index 3ba512b8..5a9fb5d1 100644 --- a/pkg/clients/grpc/atls.go +++ b/pkg/clients/grpc/atls.go @@ -16,12 +16,12 @@ import ( "github.com/absmach/magistrala/pkg/errors" "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + config "github.com/ultravioletrs/cocos/pkg/attestation" "google.golang.org/grpc/credentials" ) func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) { - err := ReadAttestationPolicy(cfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP) + err := config.ReadAttestationPolicy(cfg.AttestationPolicy, &config.AttestationPolicy) if err != nil { return nil, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err) } diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index 863c9654..7482de25 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-sev-guest/proto/check" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + att "github.com/ultravioletrs/cocos/pkg/attestation" ) func TestNewClient(t *testing.T) { @@ -200,8 +201,9 @@ func TestClientSecure(t *testing.T) { } func TestReadAttestationPolicy(t *testing.T) { - validJSON := `{"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}` + validJSON := `{"pcr_values":{"sha256":{"0":"123"},"sha384":{"0":"123"}},"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}` invalidJSON := `{"invalid_json"` + invalidJSONPCR := `{"pcr_values":{"sha256":{"0":true},"sha384":{"0":"123"}},"policy":{"report_data":"AAAA"},"root_of_trust":{"product_line":"Milan"}}` cases := []struct { name string @@ -219,19 +221,25 @@ func TestReadAttestationPolicy(t *testing.T) { name: "Invalid JSON", manifestPath: "invalid_manifest.json", fileContent: invalidJSON, - err: ErrAttestationPolicyDecode, + err: att.ErrAttestationPolicyDecode, }, { name: "Non-existent file", manifestPath: "nonexistent.json", fileContent: "", - err: errAttestationPolicyOpen, + err: att.ErrAttestationPolicyOpen, }, { name: "Empty manifest path", manifestPath: "", fileContent: "", - err: ErrAttestationPolicyMissing, + err: att.ErrAttestationPolicyMissing, + }, + { + name: "Invalid JSON PCR", + manifestPath: "invalid_manifest.json", + fileContent: invalidJSONPCR, + err: att.ErrAttestationPolicyDecode, }, } @@ -243,13 +251,13 @@ func TestReadAttestationPolicy(t *testing.T) { defer os.Remove(tt.manifestPath) } - config := check.Config{} - err := ReadAttestationPolicy(tt.manifestPath, &config) + config := att.Config{SnpCheck: &check.Config{}, PcrConfig: &att.PcrConfig{}} + err := att.ReadAttestationPolicy(tt.manifestPath, &config) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) if tt.err == nil { - assert.NotNil(t, config.Policy) - assert.NotNil(t, config.RootOfTrust) + assert.NotNil(t, config.SnpCheck.Policy) + assert.NotNil(t, config.SnpCheck.RootOfTrust) } }) } diff --git a/pkg/clients/grpc/grpc.go b/pkg/clients/grpc/grpc.go index dec55304..b9445c44 100644 --- a/pkg/clients/grpc/grpc.go +++ b/pkg/clients/grpc/grpc.go @@ -11,12 +11,10 @@ import ( "time" "github.com/absmach/magistrala/pkg/errors" - "github.com/google/go-sev-guest/proto/check" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/protobuf/encoding/protojson" ) type security int @@ -37,9 +35,6 @@ const ( var ( errGrpcConnect = errors.New("failed to connect to grpc server") errGrpcClose = errors.New("failed to close grpc connection") - errAttestationPolicyOpen = errors.New("failed to open Attestation Policy file") - ErrAttestationPolicyMissing = errors.New("failed due to missing Attestation Policy file") - ErrAttestationPolicyDecode = errors.New("failed to decode Attestation Policy file") errCertificateParse = errors.New("failed to parse x509 certificate") errAttVerification = errors.New("certificat is not sefl signed") errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key") @@ -55,7 +50,7 @@ type BaseConfig struct { Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` ClientCert string `env:"CLIENT_CERT" envDefault:""` ClientKey string `env:"CLIENT_KEY" envDefault:""` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` + ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` } type AgentClientConfig struct { @@ -146,7 +141,9 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) { if err != nil { return nil, secure, err } + opts = append(opts, grpc.WithTransportCredentials(tc)) + opts = append(opts, grpc.WithContextDialer(CustomDialer)) secure = withaTLS } else { conf := cfg.GetBaseConfig() @@ -198,20 +195,3 @@ func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.Tran return tc, nil, secure } - -func ReadAttestationPolicy(manifestPath string, attestationConfiguration *check.Config) error { - if manifestPath != "" { - manifest, err := os.ReadFile(manifestPath) - if err != nil { - return errors.Wrap(errAttestationPolicyOpen, err) - } - - if err := protojson.Unmarshal(manifest, attestationConfiguration); err != nil { - return errors.Wrap(ErrAttestationPolicyDecode, err) - } - - return nil - } - - return ErrAttestationPolicyMissing -} diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index 7b8cbc68..5639c29a 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -26,11 +26,12 @@ type SDK interface { Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error Data(ctx context.Context, dataset *os.File, filename string, privKey any) error Result(ctx context.Context, privKey any, resultFile *os.File) error - Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error + Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error } const ( size64 = 64 + size32 = 32 algoProgressBarDescription = "Uploading algorithm" dataProgressBarDescription = "Uploading data" resultProgressDescription = "Downloading result" @@ -120,9 +121,11 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.Fil return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile) } -func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error { +func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error { request := &agent.AttestationRequest{ - ReportData: reportData[:], + TeeNonce: reportData[:], + VtpmNonce: nonce[:], + Type: int32(attType), } stream, err := sdk.client.Attestation(ctx, request) diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index c3379722..4063ec6c 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk" "golang.org/x/crypto/sha3" "google.golang.org/grpc" @@ -364,6 +366,7 @@ func TestAttestation(t *testing.T) { resultConsumer1Key, _ := generateKeys(t, "ed25519") reportData := make([]byte, 64) + nonce := make([]byte, 64) report := []byte{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, @@ -385,7 +388,8 @@ func TestAttestation(t *testing.T) { cases := []struct { name string userKey any - reportData [agent.ReportDataSize]byte + reportData [quoteprovider.Nonce]byte + nonce [vtpm.Nonce]byte response *agent.AttestationResponse svcRes []byte err error @@ -393,7 +397,8 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report successfully", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [quoteprovider.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, }, @@ -403,7 +408,8 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report with different key type", userKey: resultConsumer1Key, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [quoteprovider.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, }, @@ -413,7 +419,8 @@ func TestAttestation(t *testing.T) { { name: "failed to fetch attestation report", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [quoteprovider.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, }, @@ -422,7 +429,8 @@ func TestAttestation(t *testing.T) { { name: "invalid report data", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte{}, + reportData: [quoteprovider.Nonce]byte{}, + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, }, @@ -433,7 +441,7 @@ func TestAttestation(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) + svcCall := svc.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) file, err := os.CreateTemp("", "attestation") require.NoError(t, err) @@ -442,7 +450,7 @@ func TestAttestation(t *testing.T) { os.Remove(file.Name()) }) - err = sdk.Attestation(context.Background(), tc.reportData, file) + err = sdk.Attestation(context.Background(), tc.reportData, tc.nonce, 0, file) require.NoError(t, file.Close()) diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index 47e5909a..1c0d93ea 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -74,17 +74,17 @@ func (_c *SDK_Algo_Call) RunAndReturn(run func(context.Context, *os.File, *os.Fi return _c } -// Attestation provides a mock function with given fields: ctx, reportData, attestationFile -func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error { - ret := _m.Called(ctx, reportData, attestationFile) +// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType, attestationFile +func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File) error { + ret := _m.Called(ctx, reportData, nonce, attType, attestationFile) if len(ret) == 0 { panic("no return value specified for Attestation") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok { - r0 = rf(ctx, reportData, attestationFile) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, int, *os.File) error); ok { + r0 = rf(ctx, reportData, nonce, attType, attestationFile) } else { r0 = ret.Error(0) } @@ -100,14 +100,16 @@ type SDK_Attestation_Call struct { // Attestation is a helper method to define mock.On call // - ctx context.Context // - reportData [64]byte +// - nonce [32]byte +// - attType int // - attestationFile *os.File -func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, attestationFile interface{}) *SDK_Attestation_Call { - return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, attestationFile)} +func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}, attestationFile interface{}) *SDK_Attestation_Call { + return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType, attestationFile)} } -func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, attestationFile *os.File)) *SDK_Attestation_Call { +func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File)) *SDK_Attestation_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([64]byte), args[2].(*os.File)) + run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(int), args[4].(*os.File)) }) return _c } @@ -117,7 +119,7 @@ func (_c *SDK_Attestation_Call) Return(_a0 error) *SDK_Attestation_Call { return _c } -func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, *os.File) error) *SDK_Attestation_Call { +func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, int, *os.File) error) *SDK_Attestation_Call { _c.Call.Return(run) return _c } diff --git a/scripts/attestation_policy/attestation_policy.json b/scripts/attestation_policy/attestation_policy.json index 460066a3..900019fc 100644 --- a/scripts/attestation_policy/attestation_policy.json +++ b/scripts/attestation_policy/attestation_policy.json @@ -1,28 +1,52 @@ { - "policy": { - "policy": 196608, - "family_id": "AAAAAAAAAAAAAAAAAAAAAA==", - "image_id": "AAAAAAAAAAAAAAAAAAAAAA==", - "vmpl": 0, - "minimum_tcb": 15352208179752599555, - "minimum_launch_tcb": 15352208179752599555, - "require_author_key": false, - "measurement": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", - "host_data": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - "report_id_ma": "//////////////////////////////////////////8=", - "chip_id": "GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==", - "minimum_build": 8, - "minimum_version": "1.55", - "permit_provisional_firmware": true, - "require_id_block": false, - "product": { - "name": 1 + "pcr_values": { + "sha256": { + "0": "71e0cc99e4609fdbc44698cceeda9e5ecb2f74fe07bd10710d5330e0eb6bd32b", + "1": "a40e22460c21d2450367ca70c751ec0ae5ae1072994a131287a96eadc295603b", + "2": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "3": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "4": "e16812b9181e13078b29f2e4844be7087f9e1bbffc3cb4171d2813580cafdb8d", + "5": "a5ceb755d043f32431d63e39f5161464620a3437280494b5850dc1b47cc074e0", + "6": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "7": "70d12f32fdb109ba0960697b5a8d5d8d860b004a757fe2471be2c2a19ec1a765", + "9": "2add30b0f2b31480ee5eb802c436cfffe77ceebc6009e063e84fc6a6ef2c05ac" + }, + "sha384": { + "0": "ff93a763afde2c4a152d4843d9fcabe73a70d4f34bf8861845f2ab08440c1f0742b5882ed7f2524e38a3a6e40fbcdfca", + "1": "c9b3bcc22d856cbc5be2a2bf72d81819df325db083cfea20e84d082a87f44d643e6fca98f29eb3cce4c87eed2dbca2e5", + "2": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "3": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "4": "d18d213c26e7bc309e52448bde2f0a8ef86be388223f64f85c4e0c625f1e0a7f8c901d4f7c98f8445730bc63c4dfa88d", + "5": "c50b529497c7f441ea47305587d6ce83e2e31f7b4fab6c13dc0b0c3c900e1d0caf0768321100927862df142bf0465ee4", + "6": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "7": "ea40cbd8f51eed103d75821340e71fa3c0cfde3e75c360b4c9aca534b7fed021e12f8890acef36ccfe12b33ea4111576", + "9": "02556c6b494abaf21481def35b38574e80dc68f20ceb8385f78a5ad4ecfbab60f9fcfca7c69f09a081fdd4ca13f3c14d" } }, + "policy": { + "chip_id": "GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==", + "family_id": "AAAAAAAAAAAAAAAAAAAAAA==", + "host_data": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "image_id": "AAAAAAAAAAAAAAAAAAAAAA==", + "measurement": "oDYo4e98Da2Fy73nDVZmxiWiz+5gnxae7NMRtdfnwpbBuVYZsI0mynz3fpfe+YIX", + "minimum_build": 8, + "minimum_launch_tcb": 15352208179752599555, + "minimum_tcb": 15352208179752599555, + "minimum_version": "1.55", + "permit_provisional_firmware": true, + "policy": 196608, + "product": { + "name": 1 + }, + "report_id_ma": "//////////////////////////////////////////8=", + "require_author_key": false, + "require_id_block": false, + "vmpl": 2 + }, "root_of_trust": { - "product": "Milan", "check_crl": true, "disallow_network": false, + "product": "Milan", "product_line": "Milan" } } diff --git a/scripts/attestation_policy/pcr_values.json b/scripts/attestation_policy/pcr_values.json new file mode 100644 index 00000000..67286ac5 --- /dev/null +++ b/scripts/attestation_policy/pcr_values.json @@ -0,0 +1,26 @@ +{ + "pcr_values": { + "sha256": { + "0": "71e0cc99e4609fdbc44698cceeda9e5ecb2f74fe07bd10710d5330e0eb6bd32b", + "1": "a40e22460c21d2450367ca70c751ec0ae5ae1072994a131287a96eadc295603b", + "2": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "3": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "4": "e16812b9181e13078b29f2e4844be7087f9e1bbffc3cb4171d2813580cafdb8d", + "5": "a5ceb755d043f32431d63e39f5161464620a3437280494b5850dc1b47cc074e0", + "6": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", + "7": "70d12f32fdb109ba0960697b5a8d5d8d860b004a757fe2471be2c2a19ec1a765", + "9": "2add30b0f2b31480ee5eb802c436cfffe77ceebc6009e063e84fc6a6ef2c05ac" + }, + "sha384": { + "0": "ff93a763afde2c4a152d4843d9fcabe73a70d4f34bf8861845f2ab08440c1f0742b5882ed7f2524e38a3a6e40fbcdfca", + "1": "c9b3bcc22d856cbc5be2a2bf72d81819df325db083cfea20e84d082a87f44d643e6fca98f29eb3cce4c87eed2dbca2e5", + "2": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "3": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "4": "d18d213c26e7bc309e52448bde2f0a8ef86be388223f64f85c4e0c625f1e0a7f8c901d4f7c98f8445730bc63c4dfa88d", + "5": "c50b529497c7f441ea47305587d6ce83e2e31f7b4fab6c13dc0b0c3c900e1d0caf0768321100927862df142bf0465ee4", + "6": "518923b0f955d08da077c96aaba522b9decede61c599cea6c41889cfbea4ae4d50529d96fe4d1afdafb65e7f95bf23c4", + "7": "ea40cbd8f51eed103d75821340e71fa3c0cfde3e75c360b4c9aca534b7fed021e12f8890acef36ccfe12b33ea4111576", + "9": "02556c6b494abaf21481def35b38574e80dc68f20ceb8385f78a5ad4ecfbab60f9fcfca7c69f09a081fdd4ca13f3c14d" + } + } +} diff --git a/scripts/attestation_policy/src/main.rs b/scripts/attestation_policy/src/main.rs index 7cbd308b..2d7491e4 100644 --- a/scripts/attestation_policy/src/main.rs +++ b/scripts/attestation_policy/src/main.rs @@ -1,12 +1,15 @@ use base64::prelude::*; use clap::{value_parser, Arg, Command}; use serde::Serialize; +use serde_json::Value; use sev::firmware::host::*; use std::arch::x86_64::__cpuid; -use std::fs::File; +use std::fs::{read_to_string, File}; use std::io::Write; const ATTESTATION_POLICY_JSON: &str = "attestation_policy.json"; +const PCR_VALUES_JSON: &str = "pcr_values.json"; + const EXTENDED_FAMILY_SHIFT: u32 = 20; const EXTENDED_MODEL_SHIFT: u32 = 16; const FAMILY_SHIFT: u32 = 8; @@ -123,7 +126,7 @@ fn main() { let policy: u64 = *matches.get_one::("policy").unwrap(); let family_id = BASE64_STANDARD.encode(vec![0; 16]); let image_id = BASE64_STANDARD.encode(vec![0; 16]); - let vmpl = 0; + let vmpl = 2; let minimum_tcb = get_uint64_from_tcb(&status.platform_tcb_version); let minimum_launch_tcb = get_uint64_from_tcb(&status.platform_tcb_version); let require_author_key = false; @@ -169,10 +172,33 @@ fn main() { root_of_trust, }; - let json = serde_json::to_string_pretty(&computation).expect("Failed to serialize to JSON"); - let mut file = File::create(ATTESTATION_POLICY_JSON).expect("Failed to create file"); - file.write_all(json.as_bytes()) - .expect("Failed to write to file"); + let mut computation_value = + serde_json::to_value(&computation).expect("Failed to convert computation to JSON"); + + // Read and parse the pcr_values.json file. + let pcr_content = read_to_string(PCR_VALUES_JSON).expect("Failed to read pcr_values.json"); + let pcr_value: Value = + serde_json::from_str(&pcr_content).expect("Failed to parse pcr_values.json"); + + // Merge the pcr_values into the main JSON object. + if let Value::Object(ref mut main_map) = computation_value { + if let Value::Object(pcr_map) = pcr_value { + // The keys in pcr_map (e.g., "pcr_values") will be added + main_map.extend(pcr_map); + } else { + eprintln!("{} is not a JSON object.", PCR_VALUES_JSON); + } + } else { + eprintln!("The computed JSON is not an object."); + } + + // Serialize the merged JSON and write to file. + let merged_json = + serde_json::to_string_pretty(&computation_value).expect("Failed to serialize merged JSON"); + let mut file = + File::create(ATTESTATION_POLICY_JSON).expect("Failed to create attestation policy file"); + file.write_all(merged_json.as_bytes()) + .expect("Failed to write merged JSON to file"); println!( "AttestationPolicy JSON has been written to {}", diff --git a/test/manual/README.md b/test/manual/README.md index 9694015c..a9f31cb5 100644 --- a/test/manual/README.md +++ b/test/manual/README.md @@ -67,7 +67,15 @@ export AGENT_GRPC_ATTESTATION_POLICY=./scripts/attestation_policy/attestation_po export AGENT_GRPC_ATTESTED_TLS=true # Retrieve Attestation -./build/cocos-cli attestation get '' +# Three different attestation reports can be retrieved: +# - SEV-SNP with argument snp for attestation get command. +./build/cocos-cli attestation get snp --tee '' + +# - vTPM with argument vtpm for attestation get command. +./build/cocos-cli attestation get vtpm --vtpm '' + +# - vTPM with SEV-SNP with argument snp-vtpm for attestation get command. +./build/cocos-cli attestation get snp-vtpm --tee '' --vtpm '' # Validate Attestation # Product name must be Milan or Genoa