COCOS-430 - Refactor gRPC server handlers to use a map for improved organization and add validation for nonce lengths in attestation requests (#477)

* Refactor gRPC server handlers to use a map for improved organization and add validation for nonce lengths in attestation requests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance nonce validation and improve error handling in gRPC server methods

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-07-14 14:46:02 +03:00
committed by GitHub
parent 45187d7f41
commit 85a2b7a6c8
2 changed files with 570 additions and 209 deletions
+290 -185
View File
@@ -10,6 +10,7 @@ import (
"io"
"strconv"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/transport/grpc"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation"
@@ -34,54 +35,69 @@ var (
var _ agent.AgentServiceServer = (*grpcServer)(nil)
type grpcServer struct {
algo grpc.Handler
data grpc.Handler
result grpc.Handler
attestation grpc.Handler
imaMeasurements grpc.Handler
attestationResult grpc.Handler
handlers map[string]grpc.Handler
agent.UnimplementedAgentServiceServer
}
type endpointConfig struct {
endpoint func(agent.Service) endpoint.Endpoint
decodeRequest grpc.DecodeRequestFunc
encodeResponse grpc.EncodeResponseFunc
}
// NewServer returns new AgentServiceServer instance.
func NewServer(svc agent.Service) agent.AgentServiceServer {
// Define endpoint configurations
endpoints := map[string]endpointConfig{
"algo": {
endpoint: algoEndpoint,
decodeRequest: decodeAlgoRequest,
encodeResponse: encodeAlgoResponse,
},
"data": {
endpoint: dataEndpoint,
decodeRequest: decodeDataRequest,
encodeResponse: encodeDataResponse,
},
"result": {
endpoint: resultEndpoint,
decodeRequest: decodeResultRequest,
encodeResponse: encodeResultResponse,
},
"attestation": {
endpoint: attestationEndpoint,
decodeRequest: decodeAttestationRequest,
encodeResponse: encodeAttestationResponse,
},
"imaMeasurements": {
endpoint: imaMeasurementsEndpoint,
decodeRequest: decodeIMAMeasurementsRequest,
encodeResponse: encodeIMAMeasurementsResponse,
},
"attestationResult": {
endpoint: attestationResultEndpoint,
decodeRequest: decodeAttestationResultRequest,
encodeResponse: encodeAttestationResultResponse,
},
}
// Create handlers using the configurations
handlers := make(map[string]grpc.Handler)
for name, config := range endpoints {
handlers[name] = grpc.NewServer(
config.endpoint(svc),
config.decodeRequest,
config.encodeResponse,
)
}
return &grpcServer{
algo: grpc.NewServer(
algoEndpoint(svc),
decodeAlgoRequest,
encodeAlgoResponse,
),
data: grpc.NewServer(
dataEndpoint(svc),
decodeDataRequest,
encodeDataResponse,
),
result: grpc.NewServer(
resultEndpoint(svc),
decodeResultRequest,
encodeResultResponse,
),
attestation: grpc.NewServer(
attestationEndpoint(svc),
decodeAttestationRequest,
encodeAttestationResponse,
),
imaMeasurements: grpc.NewServer(
imaMeasurementsEndpoint(svc),
decodeIMAMeasurementsRequest,
encodeIMAMeasurementsResponse,
),
attestationResult: grpc.NewServer(
attestationResultEndpoint(svc),
decodeAttestationResultRequest,
encodeAttestationResultResponse,
),
handlers: handlers,
}
}
func decodeAlgoRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.AlgoRequest)
return algoReq{
Algorithm: req.Algorithm,
Requirements: req.Requirements,
@@ -94,7 +110,6 @@ func encodeAlgoResponse(_ context.Context, response interface{}) (interface{}, e
func decodeDataRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.DataRequest)
return dataReq{
Dataset: req.Dataset,
Filename: req.Filename,
@@ -116,22 +131,47 @@ func encodeResultResponse(_ context.Context, response interface{}) (interface{},
}, nil
}
func validateNonce(nonce []byte, maxLen int, target interface{}) error {
if len(nonce) > maxLen {
switch maxLen {
case quoteprovider.Nonce:
return ErrTEENonceLength
case vtpm.Nonce:
return ErrVTPMNonceLength
default:
return ErrTokenNonceLength
}
}
switch t := target.(type) {
case *[quoteprovider.Nonce]byte:
copy(t[:], nonce)
case *[vtpm.Nonce]byte:
copy(t[:], nonce)
default:
return fmt.Errorf("unsupported target type for nonce validation: %T", target)
}
return nil
}
func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.AttestationRequest)
var reportData [quoteprovider.Nonce]byte
var nonce [vtpm.Nonce]byte
if len(req.TeeNonce) > quoteprovider.Nonce {
return nil, ErrTEENonceLength
if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil {
return nil, err
}
if len(req.VtpmNonce) > vtpm.Nonce {
return nil, ErrVTPMNonceLength
if err := validateNonce(req.VtpmNonce, vtpm.Nonce, &nonce); err != nil {
return nil, err
}
copy(reportData[:], req.TeeNonce)
copy(nonce[:], req.VtpmNonce)
return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: attestation.PlatformType(req.Type)}, nil
return attestationReq{
TeeNonce: reportData,
VtpmNonce: nonce,
AttType: attestation.PlatformType(req.Type),
}, nil
}
func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) {
@@ -141,6 +181,20 @@ func encodeAttestationResponse(_ context.Context, response interface{}) (interfa
}, nil
}
func decodeAttestationResultRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.AttestationResultRequest)
var nonce [vtpm.Nonce]byte
if err := validateNonce(req.TokenNonce, vtpm.Nonce, &nonce); err != nil {
return nil, err
}
return FetchAttestationResultReq{
tokenNonce: nonce,
AttType: attestation.PlatformType(req.Type),
}, nil
}
func encodeAttestationResultResponse(_ context.Context, response interface{}) (interface{}, error) {
res := response.(fetchAttestationResultRes)
return &agent.AttestationResultResponse{
@@ -148,127 +202,6 @@ func encodeAttestationResultResponse(_ context.Context, response interface{}) (i
}, nil
}
func decodeAttestationResultRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*agent.AttestationResultRequest)
var nonce [vtpm.Nonce]byte
if len(req.TokenNonce) > vtpm.Nonce {
return nil, ErrVTPMNonceLength
}
copy(nonce[:], req.TokenNonce)
return FetchAttestationResultReq{tokenNonce: nonce, AttType: attestation.PlatformType(req.Type)}, nil
}
// Algo implements agent.AgentServiceServer.
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
var algoFile, reqFile []byte
for {
algoChunk, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
algoFile = append(algoFile, algoChunk.Algorithm...)
reqFile = append(reqFile, algoChunk.Requirements...)
}
_, res, err := s.algo.ServeGRPC(stream.Context(), &agent.AlgoRequest{Algorithm: algoFile, Requirements: reqFile})
if err != nil {
return err
}
ar := res.(*agent.AlgoResponse)
return stream.SendAndClose(ar)
}
// Data implements agent.AgentServiceServer.
func (s *grpcServer) Data(stream agent.AgentService_DataServer) error {
var dataFile []byte
var filename string
for {
dataChunk, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
dataFile = append(dataFile, dataChunk.Dataset...)
filename = dataChunk.Filename
}
_, res, err := s.data.ServeGRPC(stream.Context(), &agent.DataRequest{Dataset: dataFile, Filename: filename})
if err != nil {
return err
}
ar := res.(*agent.DataResponse)
return stream.SendAndClose(ar)
}
func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error {
_, res, err := s.result.ServeGRPC(stream.Context(), req)
if err != nil {
return err
}
rr := res.(*agent.ResultResponse)
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
return status.Error(codes.Internal, err.Error())
}
resultBuffer := bytes.NewBuffer(rr.File)
buf := make([]byte, bufferSize)
for {
n, err := resultBuffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
if err := stream.Send(&agent.ResultResponse{File: buf[:n]}); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
return nil
}
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
_, res, err := s.attestation.ServeGRPC(stream.Context(), req)
if err != nil {
return err
}
rr := res.(*agent.AttestationResponse)
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
return status.Error(codes.Internal, err.Error())
}
attestationBuffer := bytes.NewBuffer(rr.File)
buf := make([]byte, bufferSize)
for {
n, err := attestationBuffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
if err := stream.Send(&agent.AttestationResponse{File: buf[:n]}); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
return nil
}
func decodeIMAMeasurementsRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
return imaMeasurementsReq{}, nil
}
@@ -281,56 +214,228 @@ func encodeIMAMeasurementsResponse(_ context.Context, response interface{}) (int
}, nil
}
func (s *grpcServer) streamingHandler(
ctx context.Context,
handlerName string,
req interface{},
stream interface{},
sendFn func([]byte) error,
getFileData func(interface{}) []byte,
) error {
handler, ok := s.handlers[handlerName]
if !ok {
return status.Errorf(codes.NotFound, "handler %q not found", handlerName)
}
_, res, err := handler.ServeGRPC(ctx, req)
if err != nil {
return err
}
fileData := getFileData(res)
// Set file size header
if setter, ok := stream.(interface{ SetHeader(metadata.MD) error }); ok {
if err := setter.SetHeader(metadata.New(map[string]string{
FileSizeKey: fmt.Sprint(len(fileData)),
})); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
// Stream the file data
return s.streamFileData(bytes.NewBuffer(fileData), sendFn)
}
func (s *grpcServer) streamFileData(buffer *bytes.Buffer, sendFn func([]byte) error) error {
buf := make([]byte, bufferSize)
for {
n, err := buffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
if err := sendFn(buf[:n]); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
return nil
}
func receiveStreamingData(getData func() ([]byte, string, error)) ([]byte, string, error) {
var data []byte
var filename string
for {
chunk, fname, err := getData()
if err == io.EOF {
break
}
if err != nil {
return nil, "", status.Error(codes.Internal, err.Error())
}
data = append(data, chunk...)
if fname != "" {
filename = fname
}
}
return data, filename, nil
}
// Algo implements agent.AgentServiceServer.
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
algoFile, reqFile, err := s.receiveAlgoData(stream)
if err != nil {
return err
}
_, res, err := s.handlers["algo"].ServeGRPC(stream.Context(), &agent.AlgoRequest{
Algorithm: algoFile,
Requirements: reqFile,
})
if err != nil {
return err
}
return stream.SendAndClose(res.(*agent.AlgoResponse))
}
func (s *grpcServer) receiveAlgoData(stream agent.AgentService_AlgoServer) ([]byte, []byte, error) {
var algoFile, reqFile []byte
for {
chunk, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, nil, status.Error(codes.Internal, err.Error())
}
algoFile = append(algoFile, chunk.Algorithm...)
reqFile = append(reqFile, chunk.Requirements...)
}
return algoFile, reqFile, nil
}
// Data implements agent.AgentServiceServer.
func (s *grpcServer) Data(stream agent.AgentService_DataServer) error {
dataFile, filename, err := receiveStreamingData(func() ([]byte, string, error) {
chunk, err := stream.Recv()
if err != nil {
return nil, "", err
}
return chunk.Dataset, chunk.Filename, nil
})
if err != nil {
return err
}
_, res, err := s.handlers["data"].ServeGRPC(stream.Context(), &agent.DataRequest{
Dataset: dataFile,
Filename: filename,
})
if err != nil {
return err
}
return stream.SendAndClose(res.(*agent.DataResponse))
}
func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error {
return s.streamingHandler(
stream.Context(),
"result",
req,
stream,
func(data []byte) error {
return stream.Send(&agent.ResultResponse{File: data})
},
func(res interface{}) []byte {
return res.(*agent.ResultResponse).File
},
)
}
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
return s.streamingHandler(
stream.Context(),
"attestation",
req,
stream,
func(data []byte) error {
return stream.Send(&agent.AttestationResponse{File: data})
},
func(res interface{}) []byte {
return res.(*agent.AttestationResponse).File
},
)
}
func (s *grpcServer) IMAMeasurements(req *agent.IMAMeasurementsRequest, stream agent.AgentService_IMAMeasurementsServer) error {
_, res, err := s.imaMeasurements.ServeGRPC(stream.Context(), req)
_, res, err := s.handlers["imaMeasurements"].ServeGRPC(stream.Context(), req)
if err != nil {
return err
}
rr := res.(*agent.IMAMeasurementsResponse)
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: strconv.Itoa(len(rr.File))})); err != nil {
if err := stream.SetHeader(metadata.New(map[string]string{
FileSizeKey: strconv.Itoa(len(rr.File)),
})); err != nil {
return status.Error(codes.Internal, err.Error())
}
imaBuff := bytes.NewBuffer(rr.File)
pcr10Buff := bytes.NewBuffer(rr.Pcr10)
return s.streamDualBuffers(
bytes.NewBuffer(rr.File),
bytes.NewBuffer(rr.Pcr10),
func(fileData, pcr10Data []byte) error {
return stream.Send(&agent.IMAMeasurementsResponse{
File: fileData,
Pcr10: pcr10Data,
})
},
)
}
imaResBuff := make([]byte, bufferSize)
pcr10ResBuff := make([]byte, bufferSize)
func (s *grpcServer) streamDualBuffers(
buf1, buf2 *bytes.Buffer,
sendFn func([]byte, []byte) error,
) error {
buff1 := make([]byte, bufferSize)
buff2 := make([]byte, bufferSize)
for {
nIma, errIma := imaBuff.Read(imaResBuff)
if errIma != nil && errIma != io.EOF {
return status.Error(codes.Internal, errIma.Error())
n1, err1 := buf1.Read(buff1)
if err1 != nil && err1 != io.EOF {
return status.Error(codes.Internal, err1.Error())
}
nPcr, errPcr := pcr10Buff.Read(pcr10ResBuff)
if errPcr != nil && errPcr != io.EOF {
return status.Error(codes.Internal, errPcr.Error())
n2, err2 := buf2.Read(buff2)
if err2 != nil && err2 != io.EOF {
return status.Error(codes.Internal, err2.Error())
}
if nIma == 0 && errIma == io.EOF &&
nPcr == 0 && errPcr == io.EOF {
if n1 == 0 && err1 == io.EOF && n2 == 0 && err2 == io.EOF {
break
}
if err := stream.Send(&agent.IMAMeasurementsResponse{File: imaResBuff[:nIma], Pcr10: pcr10ResBuff[:nPcr]}); err != nil {
if err := sendFn(buff1[:n1], buff2[:n2]); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
return nil
}
func (s *grpcServer) AttestationResult(ctx context.Context, req *agent.AttestationResultRequest) (*agent.AttestationResultResponse, error) {
_, res, err := s.attestationResult.ServeGRPC(ctx, req)
_, res, err := s.handlers["attestationResult"].ServeGRPC(ctx, req)
if err != nil {
return nil, err
}
rr, ok := res.(*agent.AttestationResultResponse)
rr, ok := res.(*agent.AttestationResultResponse)
if !ok {
return nil, status.Error(codes.Internal, "failed to cast response to FetchAttestationResultResponse")
return nil, status.Error(codes.Internal, "failed to cast response to AttestationResultResponse")
}
return rr, nil
+280 -24
View File
@@ -68,8 +68,9 @@ func (m *MockAgentService_ResultServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_ResultServer) SetHeader(metadata.MD) error {
return nil
func (m *MockAgentService_ResultServer) SetHeader(md metadata.MD) error {
args := m.Called(md)
return args.Error(0)
}
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
@@ -92,8 +93,46 @@ func (m *MockAgentService_AttestationServer) Send(resp *agent.AttestationRespons
return args.Error(0)
}
func (m *MockAgentService_AttestationServer) SetHeader(metadata.MD) error {
return nil
func (m *MockAgentService_AttestationServer) SetHeader(md metadata.MD) error {
args := m.Called(md)
return args.Error(0)
}
type MockAgentService_IMAMeasurementsServer struct {
grpc.ServerStream
mock.Mock
ctx context.Context
}
func (m *MockAgentService_IMAMeasurementsServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_IMAMeasurementsServer) Send(resp *agent.IMAMeasurementsResponse) error {
args := m.Called(resp)
return args.Error(0)
}
func (m *MockAgentService_IMAMeasurementsServer) SetHeader(md metadata.MD) error {
args := m.Called(md)
return args.Error(0)
}
func TestNewServer(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
grpcServer, ok := server.(*grpcServer)
assert.True(t, ok)
assert.NotNil(t, grpcServer.handlers)
assert.Len(t, grpcServer.handlers, 6) // Should have 6 handlers
// Check that all expected handlers are present
expectedHandlers := []string{"algo", "data", "result", "attestation", "imaMeasurements", "attestationResult"}
for _, handler := range expectedHandlers {
assert.Contains(t, grpcServer.handlers, handler)
assert.NotNil(t, grpcServer.handlers[handler])
}
}
func TestAlgo(t *testing.T) {
@@ -102,8 +141,8 @@ func TestAlgo(t *testing.T) {
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF)
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil)
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil)
@@ -114,14 +153,33 @@ func TestAlgo(t *testing.T) {
mockService.AssertExpectations(t)
}
func TestAlgoWithMultipleChunks(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("2"), Requirements: []byte("2")}, nil).Once()
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo2"), Requirements: []byte("req2")}).Return(nil)
err := server.Algo(mockStream)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
mockService.AssertExpectations(t)
}
func TestData(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once()
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF)
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil)
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF).Once()
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil).Once()
mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil)
@@ -136,9 +194,18 @@ func TestResult(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
resultData := []byte("result data")
mockStream := &MockAgentService_ResultServer{ctx: context.Background()}
mockService.On("Result", mock.Anything).Return([]byte("result data"), nil)
mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil)
// Mock the SetHeader call
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
// Mock the Send call - it should be called with the result data
mockStream.On("Send", mock.MatchedBy(func(resp *agent.ResultResponse) bool {
return len(resp.File) > 0
})).Return(nil).Once()
mockService.On("Result", mock.Anything).Return(resultData, nil)
err := server.Result(&agent.ResultRequest{}, mockStream)
assert.NoError(t, err)
@@ -151,35 +218,137 @@ func TestAttestation(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
attestationData := []byte("attestation data")
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
// Mock the SetHeader call
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
// Mock the Send call
mockStream.On("Send", mock.MatchedBy(func(resp *agent.AttestationResponse) bool {
return len(resp.File) > 0
})).Return(nil).Once()
reportData := [quoteprovider.Nonce]byte{}
vtpmNonce := [vtpm.Nonce]byte{}
attestationType := attestation.SNP
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil)
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream)
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:], Type: int32(attestationType)}, mockStream)
assert.NoError(t, err)
mockService.AssertExpectations(t)
mockStream.AssertExpectations(t)
}
func TestIMAMeasurements(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
imaData := []byte("ima data")
pcr10Data := []byte("pcr10 data")
mockStream := &MockAgentService_IMAMeasurementsServer{ctx: context.Background()}
// Mock the SetHeader call
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
// Mock the Send call
mockStream.On("Send", mock.MatchedBy(func(resp *agent.IMAMeasurementsResponse) bool {
return len(resp.File) > 0 || len(resp.Pcr10) > 0
})).Return(nil).Once()
mockService.On("IMAMeasurements", mock.Anything).Return(imaData, pcr10Data, nil)
err := server.IMAMeasurements(&agent.IMAMeasurementsRequest{}, mockStream)
assert.NoError(t, err)
mockService.AssertExpectations(t)
mockStream.AssertExpectations(t)
}
func TestAttestationResult(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
attestationData := []byte("attestation result data")
vtpmNonce := [vtpm.Nonce]byte{}
attestationType := attestation.SNP
mockService.On("AttestationResult", mock.Anything, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
resp, err := server.AttestationResult(context.Background(), &agent.AttestationResultRequest{TokenNonce: vtpmNonce[:]})
mockService.On("AttestationResult", mock.Anything, vtpmNonce, attestationType).Return(attestationData, nil)
resp, err := server.AttestationResult(context.Background(), &agent.AttestationResultRequest{
TokenNonce: vtpmNonce[:],
Type: int32(attestationType),
})
assert.NoError(t, err)
assert.Equal(t, []byte("attestation data"), resp.File)
assert.Equal(t, attestationData, resp.File)
mockService.AssertExpectations(t)
}
func TestValidateNonce(t *testing.T) {
tests := []struct {
name string
nonce []byte
maxLen int
shouldError bool
expectedErr error
}{
{
name: "valid TEE nonce",
nonce: make([]byte, quoteprovider.Nonce),
maxLen: quoteprovider.Nonce,
shouldError: false,
},
{
name: "valid vTPM nonce",
nonce: make([]byte, vtpm.Nonce),
maxLen: vtpm.Nonce,
shouldError: false,
},
{
name: "TEE nonce too long",
nonce: make([]byte, quoteprovider.Nonce+1),
maxLen: quoteprovider.Nonce,
shouldError: true,
expectedErr: ErrTEENonceLength,
},
{
name: "vTPM nonce too long",
nonce: make([]byte, vtpm.Nonce+1),
maxLen: vtpm.Nonce,
shouldError: true,
expectedErr: ErrVTPMNonceLength,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.maxLen == quoteprovider.Nonce {
var target [quoteprovider.Nonce]byte
err := validateNonce(tt.nonce, tt.maxLen, &target)
if tt.shouldError {
assert.Error(t, err)
assert.Equal(t, tt.expectedErr, err)
} else {
assert.NoError(t, err)
}
} else {
var target [vtpm.Nonce]byte
err := validateNonce(tt.nonce, tt.maxLen, &target)
if tt.shouldError {
assert.Error(t, err)
assert.Equal(t, tt.expectedErr, err)
} else {
assert.NoError(t, err)
}
}
})
}
}
func TestDecodeAlgoRequest(t *testing.T) {
req := &agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}
decoded, err := decodeAlgoRequest(context.Background(), req)
@@ -219,11 +388,38 @@ func TestEncodeResultResponse(t *testing.T) {
}
func TestDecodeAttestationRequest(t *testing.T) {
nonce := [quoteprovider.Nonce]byte{}
req := &agent.AttestationRequest{TeeNonce: nonce[:]}
teeNonce := make([]byte, quoteprovider.Nonce)
vtpmNonce := make([]byte, vtpm.Nonce)
req := &agent.AttestationRequest{
TeeNonce: teeNonce,
VtpmNonce: vtpmNonce,
Type: int32(attestation.SNP),
}
decoded, err := decodeAttestationRequest(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded)
decodedReq := decoded.(attestationReq)
assert.Equal(t, attestation.SNP, decodedReq.AttType)
}
func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) {
// Test with TEE nonce too long
teeNonce := make([]byte, quoteprovider.Nonce+1)
req := &agent.AttestationRequest{TeeNonce: teeNonce}
_, err := decodeAttestationRequest(context.Background(), req)
assert.Error(t, err)
assert.Equal(t, ErrTEENonceLength, err)
// Test with vTPM nonce too long
vtpmNonce := make([]byte, vtpm.Nonce+1)
req = &agent.AttestationRequest{VtpmNonce: vtpmNonce}
_, err = decodeAttestationRequest(context.Background(), req)
assert.Error(t, err)
assert.Equal(t, ErrVTPMNonceLength, err)
}
func TestEncodeAttestationResponse(t *testing.T) {
@@ -232,16 +428,76 @@ func TestEncodeAttestationResponse(t *testing.T) {
assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded)
}
func TestDecodeAttestationResultRequest(t *testing.T) {
tokenNonce := make([]byte, vtpm.Nonce)
req := &agent.AttestationResultRequest{
TokenNonce: tokenNonce,
Type: int32(attestation.SNP),
}
decoded, err := decodeAttestationResultRequest(context.Background(), req)
assert.NoError(t, err)
decodedReq := decoded.(FetchAttestationResultReq)
assert.Equal(t, attestation.SNP, decodedReq.AttType)
}
func TestDecodeAttestationResultRequestWithInvalidNonce(t *testing.T) {
// Test with token nonce too long
tokenNonce := make([]byte, vtpm.Nonce+1)
req := &agent.AttestationResultRequest{TokenNonce: tokenNonce}
_, err := decodeAttestationResultRequest(context.Background(), req)
assert.Error(t, err)
assert.Equal(t, ErrVTPMNonceLength, err)
}
func TestEncodeAttestationResultResponse(t *testing.T) {
encoded, err := encodeAttestationResultResponse(context.Background(), fetchAttestationResultRes{File: []byte("attestation")})
assert.NoError(t, err)
assert.Equal(t, &agent.AttestationResultResponse{File: []byte("attestation")}, encoded)
}
func TestDecodeAttestationResultRequest(t *testing.T) {
nonce := [vtpm.Nonce]byte{}
req := &agent.AttestationResultRequest{TokenNonce: nonce[:]}
decoded, err := decodeAttestationResultRequest(context.Background(), req)
func TestDecodeIMAMeasurementsRequest(t *testing.T) {
decoded, err := decodeIMAMeasurementsRequest(context.Background(), &agent.IMAMeasurementsRequest{})
assert.NoError(t, err)
assert.Equal(t, FetchAttestationResultReq{tokenNonce: nonce}, decoded)
assert.Equal(t, imaMeasurementsReq{}, decoded)
}
func TestEncodeIMAMeasurementsResponse(t *testing.T) {
encoded, err := encodeIMAMeasurementsResponse(context.Background(), imaMeasurementsRes{
File: []byte("ima"),
PCR10: []byte("pcr10"),
})
assert.NoError(t, err)
assert.Equal(t, &agent.IMAMeasurementsResponse{
File: []byte("ima"),
Pcr10: []byte("pcr10"),
}, encoded)
}
func TestAlgoWithStreamError(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.AlgoRequest{}, assert.AnError).Once()
err := server.Algo(mockStream)
assert.Error(t, err)
mockStream.AssertExpectations(t)
}
func TestDataWithStreamError(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.DataRequest{}, assert.AnError).Once()
err := server.Data(mockStream)
assert.Error(t, err)
mockStream.AssertExpectations(t)
}