mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-430 - Refactor gRPC server handlers to use a map for improved organization and add validation for nonce lengths in attestation requests (#477)
* Refactor gRPC server handlers to use a map for improved organization and add validation for nonce lengths in attestation requests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance nonce validation and improve error handling in gRPC server methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
45187d7f41
commit
85a2b7a6c8
+290
-185
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
"github.com/go-kit/kit/transport/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
@@ -34,54 +35,69 @@ var (
|
||||
var _ agent.AgentServiceServer = (*grpcServer)(nil)
|
||||
|
||||
type grpcServer struct {
|
||||
algo grpc.Handler
|
||||
data grpc.Handler
|
||||
result grpc.Handler
|
||||
attestation grpc.Handler
|
||||
imaMeasurements grpc.Handler
|
||||
attestationResult grpc.Handler
|
||||
handlers map[string]grpc.Handler
|
||||
agent.UnimplementedAgentServiceServer
|
||||
}
|
||||
|
||||
type endpointConfig struct {
|
||||
endpoint func(agent.Service) endpoint.Endpoint
|
||||
decodeRequest grpc.DecodeRequestFunc
|
||||
encodeResponse grpc.EncodeResponseFunc
|
||||
}
|
||||
|
||||
// NewServer returns new AgentServiceServer instance.
|
||||
func NewServer(svc agent.Service) agent.AgentServiceServer {
|
||||
// Define endpoint configurations
|
||||
endpoints := map[string]endpointConfig{
|
||||
"algo": {
|
||||
endpoint: algoEndpoint,
|
||||
decodeRequest: decodeAlgoRequest,
|
||||
encodeResponse: encodeAlgoResponse,
|
||||
},
|
||||
"data": {
|
||||
endpoint: dataEndpoint,
|
||||
decodeRequest: decodeDataRequest,
|
||||
encodeResponse: encodeDataResponse,
|
||||
},
|
||||
"result": {
|
||||
endpoint: resultEndpoint,
|
||||
decodeRequest: decodeResultRequest,
|
||||
encodeResponse: encodeResultResponse,
|
||||
},
|
||||
"attestation": {
|
||||
endpoint: attestationEndpoint,
|
||||
decodeRequest: decodeAttestationRequest,
|
||||
encodeResponse: encodeAttestationResponse,
|
||||
},
|
||||
"imaMeasurements": {
|
||||
endpoint: imaMeasurementsEndpoint,
|
||||
decodeRequest: decodeIMAMeasurementsRequest,
|
||||
encodeResponse: encodeIMAMeasurementsResponse,
|
||||
},
|
||||
"attestationResult": {
|
||||
endpoint: attestationResultEndpoint,
|
||||
decodeRequest: decodeAttestationResultRequest,
|
||||
encodeResponse: encodeAttestationResultResponse,
|
||||
},
|
||||
}
|
||||
|
||||
// Create handlers using the configurations
|
||||
handlers := make(map[string]grpc.Handler)
|
||||
for name, config := range endpoints {
|
||||
handlers[name] = grpc.NewServer(
|
||||
config.endpoint(svc),
|
||||
config.decodeRequest,
|
||||
config.encodeResponse,
|
||||
)
|
||||
}
|
||||
|
||||
return &grpcServer{
|
||||
algo: grpc.NewServer(
|
||||
algoEndpoint(svc),
|
||||
decodeAlgoRequest,
|
||||
encodeAlgoResponse,
|
||||
),
|
||||
data: grpc.NewServer(
|
||||
dataEndpoint(svc),
|
||||
decodeDataRequest,
|
||||
encodeDataResponse,
|
||||
),
|
||||
result: grpc.NewServer(
|
||||
resultEndpoint(svc),
|
||||
decodeResultRequest,
|
||||
encodeResultResponse,
|
||||
),
|
||||
attestation: grpc.NewServer(
|
||||
attestationEndpoint(svc),
|
||||
decodeAttestationRequest,
|
||||
encodeAttestationResponse,
|
||||
),
|
||||
imaMeasurements: grpc.NewServer(
|
||||
imaMeasurementsEndpoint(svc),
|
||||
decodeIMAMeasurementsRequest,
|
||||
encodeIMAMeasurementsResponse,
|
||||
),
|
||||
attestationResult: grpc.NewServer(
|
||||
attestationResultEndpoint(svc),
|
||||
decodeAttestationResultRequest,
|
||||
encodeAttestationResultResponse,
|
||||
),
|
||||
handlers: handlers,
|
||||
}
|
||||
}
|
||||
|
||||
func decodeAlgoRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.AlgoRequest)
|
||||
|
||||
return algoReq{
|
||||
Algorithm: req.Algorithm,
|
||||
Requirements: req.Requirements,
|
||||
@@ -94,7 +110,6 @@ func encodeAlgoResponse(_ context.Context, response interface{}) (interface{}, e
|
||||
|
||||
func decodeDataRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.DataRequest)
|
||||
|
||||
return dataReq{
|
||||
Dataset: req.Dataset,
|
||||
Filename: req.Filename,
|
||||
@@ -116,22 +131,47 @@ func encodeResultResponse(_ context.Context, response interface{}) (interface{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateNonce(nonce []byte, maxLen int, target interface{}) error {
|
||||
if len(nonce) > maxLen {
|
||||
switch maxLen {
|
||||
case quoteprovider.Nonce:
|
||||
return ErrTEENonceLength
|
||||
case vtpm.Nonce:
|
||||
return ErrVTPMNonceLength
|
||||
default:
|
||||
return ErrTokenNonceLength
|
||||
}
|
||||
}
|
||||
|
||||
switch t := target.(type) {
|
||||
case *[quoteprovider.Nonce]byte:
|
||||
copy(t[:], nonce)
|
||||
case *[vtpm.Nonce]byte:
|
||||
copy(t[:], nonce)
|
||||
default:
|
||||
return fmt.Errorf("unsupported target type for nonce validation: %T", target)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.AttestationRequest)
|
||||
var reportData [quoteprovider.Nonce]byte
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if len(req.TeeNonce) > quoteprovider.Nonce {
|
||||
return nil, ErrTEENonceLength
|
||||
if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(req.VtpmNonce) > vtpm.Nonce {
|
||||
return nil, ErrVTPMNonceLength
|
||||
if err := validateNonce(req.VtpmNonce, vtpm.Nonce, &nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
copy(reportData[:], req.TeeNonce)
|
||||
copy(nonce[:], req.VtpmNonce)
|
||||
return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: attestation.PlatformType(req.Type)}, nil
|
||||
return attestationReq{
|
||||
TeeNonce: reportData,
|
||||
VtpmNonce: nonce,
|
||||
AttType: attestation.PlatformType(req.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
@@ -141,6 +181,20 @@ func encodeAttestationResponse(_ context.Context, response interface{}) (interfa
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeAttestationResultRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.AttestationResultRequest)
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if err := validateNonce(req.TokenNonce, vtpm.Nonce, &nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return FetchAttestationResultReq{
|
||||
tokenNonce: nonce,
|
||||
AttType: attestation.PlatformType(req.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAttestationResultResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
res := response.(fetchAttestationResultRes)
|
||||
return &agent.AttestationResultResponse{
|
||||
@@ -148,127 +202,6 @@ func encodeAttestationResultResponse(_ context.Context, response interface{}) (i
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeAttestationResultRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.AttestationResultRequest)
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if len(req.TokenNonce) > vtpm.Nonce {
|
||||
return nil, ErrVTPMNonceLength
|
||||
}
|
||||
|
||||
copy(nonce[:], req.TokenNonce)
|
||||
return FetchAttestationResultReq{tokenNonce: nonce, AttType: attestation.PlatformType(req.Type)}, nil
|
||||
}
|
||||
|
||||
// Algo implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
|
||||
var algoFile, reqFile []byte
|
||||
for {
|
||||
algoChunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
algoFile = append(algoFile, algoChunk.Algorithm...)
|
||||
reqFile = append(reqFile, algoChunk.Requirements...)
|
||||
}
|
||||
_, res, err := s.algo.ServeGRPC(stream.Context(), &agent.AlgoRequest{Algorithm: algoFile, Requirements: reqFile})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ar := res.(*agent.AlgoResponse)
|
||||
return stream.SendAndClose(ar)
|
||||
}
|
||||
|
||||
// Data implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Data(stream agent.AgentService_DataServer) error {
|
||||
var dataFile []byte
|
||||
var filename string
|
||||
for {
|
||||
dataChunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
dataFile = append(dataFile, dataChunk.Dataset...)
|
||||
filename = dataChunk.Filename
|
||||
}
|
||||
_, res, err := s.data.ServeGRPC(stream.Context(), &agent.DataRequest{Dataset: dataFile, Filename: filename})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ar := res.(*agent.DataResponse)
|
||||
return stream.SendAndClose(ar)
|
||||
}
|
||||
|
||||
func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error {
|
||||
_, res, err := s.result.ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.ResultResponse)
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
resultBuffer := bytes.NewBuffer(rr.File)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := resultBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.ResultResponse{File: buf[:n]}); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
|
||||
_, res, err := s.attestation.ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.AttestationResponse)
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
attestationBuffer := bytes.NewBuffer(rr.File)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := attestationBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.AttestationResponse{File: buf[:n]}); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeIMAMeasurementsRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
return imaMeasurementsReq{}, nil
|
||||
}
|
||||
@@ -281,56 +214,228 @@ func encodeIMAMeasurementsResponse(_ context.Context, response interface{}) (int
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) streamingHandler(
|
||||
ctx context.Context,
|
||||
handlerName string,
|
||||
req interface{},
|
||||
stream interface{},
|
||||
sendFn func([]byte) error,
|
||||
getFileData func(interface{}) []byte,
|
||||
) error {
|
||||
handler, ok := s.handlers[handlerName]
|
||||
if !ok {
|
||||
return status.Errorf(codes.NotFound, "handler %q not found", handlerName)
|
||||
}
|
||||
|
||||
_, res, err := handler.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileData := getFileData(res)
|
||||
|
||||
// Set file size header
|
||||
if setter, ok := stream.(interface{ SetHeader(metadata.MD) error }); ok {
|
||||
if err := setter.SetHeader(metadata.New(map[string]string{
|
||||
FileSizeKey: fmt.Sprint(len(fileData)),
|
||||
})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the file data
|
||||
return s.streamFileData(bytes.NewBuffer(fileData), sendFn)
|
||||
}
|
||||
|
||||
func (s *grpcServer) streamFileData(buffer *bytes.Buffer, sendFn func([]byte) error) error {
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
n, err := buffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := sendFn(buf[:n]); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func receiveStreamingData(getData func() ([]byte, string, error)) ([]byte, string, error) {
|
||||
var data []byte
|
||||
var filename string
|
||||
|
||||
for {
|
||||
chunk, fname, err := getData()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
data = append(data, chunk...)
|
||||
if fname != "" {
|
||||
filename = fname
|
||||
}
|
||||
}
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
// Algo implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
|
||||
algoFile, reqFile, err := s.receiveAlgoData(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, res, err := s.handlers["algo"].ServeGRPC(stream.Context(), &agent.AlgoRequest{
|
||||
Algorithm: algoFile,
|
||||
Requirements: reqFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.SendAndClose(res.(*agent.AlgoResponse))
|
||||
}
|
||||
|
||||
func (s *grpcServer) receiveAlgoData(stream agent.AgentService_AlgoServer) ([]byte, []byte, error) {
|
||||
var algoFile, reqFile []byte
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
algoFile = append(algoFile, chunk.Algorithm...)
|
||||
reqFile = append(reqFile, chunk.Requirements...)
|
||||
}
|
||||
return algoFile, reqFile, nil
|
||||
}
|
||||
|
||||
// Data implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Data(stream agent.AgentService_DataServer) error {
|
||||
dataFile, filename, err := receiveStreamingData(func() ([]byte, string, error) {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return chunk.Dataset, chunk.Filename, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, res, err := s.handlers["data"].ServeGRPC(stream.Context(), &agent.DataRequest{
|
||||
Dataset: dataFile,
|
||||
Filename: filename,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.SendAndClose(res.(*agent.DataResponse))
|
||||
}
|
||||
|
||||
func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error {
|
||||
return s.streamingHandler(
|
||||
stream.Context(),
|
||||
"result",
|
||||
req,
|
||||
stream,
|
||||
func(data []byte) error {
|
||||
return stream.Send(&agent.ResultResponse{File: data})
|
||||
},
|
||||
func(res interface{}) []byte {
|
||||
return res.(*agent.ResultResponse).File
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
|
||||
return s.streamingHandler(
|
||||
stream.Context(),
|
||||
"attestation",
|
||||
req,
|
||||
stream,
|
||||
func(data []byte) error {
|
||||
return stream.Send(&agent.AttestationResponse{File: data})
|
||||
},
|
||||
func(res interface{}) []byte {
|
||||
return res.(*agent.AttestationResponse).File
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *grpcServer) IMAMeasurements(req *agent.IMAMeasurementsRequest, stream agent.AgentService_IMAMeasurementsServer) error {
|
||||
_, res, err := s.imaMeasurements.ServeGRPC(stream.Context(), req)
|
||||
_, res, err := s.handlers["imaMeasurements"].ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.IMAMeasurementsResponse)
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: strconv.Itoa(len(rr.File))})); err != nil {
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{
|
||||
FileSizeKey: strconv.Itoa(len(rr.File)),
|
||||
})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
imaBuff := bytes.NewBuffer(rr.File)
|
||||
pcr10Buff := bytes.NewBuffer(rr.Pcr10)
|
||||
return s.streamDualBuffers(
|
||||
bytes.NewBuffer(rr.File),
|
||||
bytes.NewBuffer(rr.Pcr10),
|
||||
func(fileData, pcr10Data []byte) error {
|
||||
return stream.Send(&agent.IMAMeasurementsResponse{
|
||||
File: fileData,
|
||||
Pcr10: pcr10Data,
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
imaResBuff := make([]byte, bufferSize)
|
||||
pcr10ResBuff := make([]byte, bufferSize)
|
||||
func (s *grpcServer) streamDualBuffers(
|
||||
buf1, buf2 *bytes.Buffer,
|
||||
sendFn func([]byte, []byte) error,
|
||||
) error {
|
||||
buff1 := make([]byte, bufferSize)
|
||||
buff2 := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
nIma, errIma := imaBuff.Read(imaResBuff)
|
||||
if errIma != nil && errIma != io.EOF {
|
||||
return status.Error(codes.Internal, errIma.Error())
|
||||
n1, err1 := buf1.Read(buff1)
|
||||
if err1 != nil && err1 != io.EOF {
|
||||
return status.Error(codes.Internal, err1.Error())
|
||||
}
|
||||
|
||||
nPcr, errPcr := pcr10Buff.Read(pcr10ResBuff)
|
||||
if errPcr != nil && errPcr != io.EOF {
|
||||
return status.Error(codes.Internal, errPcr.Error())
|
||||
n2, err2 := buf2.Read(buff2)
|
||||
if err2 != nil && err2 != io.EOF {
|
||||
return status.Error(codes.Internal, err2.Error())
|
||||
}
|
||||
|
||||
if nIma == 0 && errIma == io.EOF &&
|
||||
nPcr == 0 && errPcr == io.EOF {
|
||||
if n1 == 0 && err1 == io.EOF && n2 == 0 && err2 == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.IMAMeasurementsResponse{File: imaResBuff[:nIma], Pcr10: pcr10ResBuff[:nPcr]}); err != nil {
|
||||
if err := sendFn(buff1[:n1], buff2[:n2]); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) AttestationResult(ctx context.Context, req *agent.AttestationResultRequest) (*agent.AttestationResultResponse, error) {
|
||||
_, res, err := s.attestationResult.ServeGRPC(ctx, req)
|
||||
_, res, err := s.handlers["attestationResult"].ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rr, ok := res.(*agent.AttestationResultResponse)
|
||||
|
||||
rr, ok := res.(*agent.AttestationResultResponse)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "failed to cast response to FetchAttestationResultResponse")
|
||||
return nil, status.Error(codes.Internal, "failed to cast response to AttestationResultResponse")
|
||||
}
|
||||
|
||||
return rr, nil
|
||||
|
||||
+280
-24
@@ -68,8 +68,9 @@ func (m *MockAgentService_ResultServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
func (m *MockAgentService_ResultServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
|
||||
@@ -92,8 +93,46 @@ func (m *MockAgentService_AttestationServer) Send(resp *agent.AttestationRespons
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AttestationServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
func (m *MockAgentService_AttestationServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAgentService_IMAMeasurementsServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) Send(resp *agent.IMAMeasurementsResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
grpcServer, ok := server.(*grpcServer)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, grpcServer.handlers)
|
||||
assert.Len(t, grpcServer.handlers, 6) // Should have 6 handlers
|
||||
|
||||
// Check that all expected handlers are present
|
||||
expectedHandlers := []string{"algo", "data", "result", "attestation", "imaMeasurements", "attestationResult"}
|
||||
for _, handler := range expectedHandlers {
|
||||
assert.Contains(t, grpcServer.handlers, handler)
|
||||
assert.NotNil(t, grpcServer.handlers[handler])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
@@ -102,8 +141,8 @@ func TestAlgo(t *testing.T) {
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil)
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil)
|
||||
|
||||
@@ -114,14 +153,33 @@ func TestAlgo(t *testing.T) {
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAlgoWithMultipleChunks(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("2"), Requirements: []byte("2")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo2"), Requirements: []byte("req2")}).Return(nil)
|
||||
|
||||
err := server.Algo(mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil)
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil)
|
||||
|
||||
@@ -136,9 +194,18 @@ func TestResult(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
resultData := []byte("result data")
|
||||
mockStream := &MockAgentService_ResultServer{ctx: context.Background()}
|
||||
mockService.On("Result", mock.Anything).Return([]byte("result data"), nil)
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil)
|
||||
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call - it should be called with the result data
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.ResultResponse) bool {
|
||||
return len(resp.File) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
mockService.On("Result", mock.Anything).Return(resultData, nil)
|
||||
|
||||
err := server.Result(&agent.ResultRequest{}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
@@ -151,35 +218,137 @@ func TestAttestation(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
attestationData := []byte("attestation data")
|
||||
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
|
||||
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.AttestationResponse) bool {
|
||||
return len(resp.File) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
reportData := [quoteprovider.Nonce]byte{}
|
||||
vtpmNonce := [vtpm.Nonce]byte{}
|
||||
attestationType := attestation.SNP
|
||||
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
|
||||
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil)
|
||||
|
||||
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream)
|
||||
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:], Type: int32(attestationType)}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestIMAMeasurements(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
imaData := []byte("ima data")
|
||||
pcr10Data := []byte("pcr10 data")
|
||||
|
||||
mockStream := &MockAgentService_IMAMeasurementsServer{ctx: context.Background()}
|
||||
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.IMAMeasurementsResponse) bool {
|
||||
return len(resp.File) > 0 || len(resp.Pcr10) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
mockService.On("IMAMeasurements", mock.Anything).Return(imaData, pcr10Data, nil)
|
||||
|
||||
err := server.IMAMeasurements(&agent.IMAMeasurementsRequest{}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAttestationResult(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
attestationData := []byte("attestation result data")
|
||||
vtpmNonce := [vtpm.Nonce]byte{}
|
||||
attestationType := attestation.SNP
|
||||
mockService.On("AttestationResult", mock.Anything, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
|
||||
|
||||
resp, err := server.AttestationResult(context.Background(), &agent.AttestationResultRequest{TokenNonce: vtpmNonce[:]})
|
||||
mockService.On("AttestationResult", mock.Anything, vtpmNonce, attestationType).Return(attestationData, nil)
|
||||
|
||||
resp, err := server.AttestationResult(context.Background(), &agent.AttestationResultRequest{
|
||||
TokenNonce: vtpmNonce[:],
|
||||
Type: int32(attestationType),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("attestation data"), resp.File)
|
||||
assert.Equal(t, attestationData, resp.File)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestValidateNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nonce []byte
|
||||
maxLen int
|
||||
shouldError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid TEE nonce",
|
||||
nonce: make([]byte, quoteprovider.Nonce),
|
||||
maxLen: quoteprovider.Nonce,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid vTPM nonce",
|
||||
nonce: make([]byte, vtpm.Nonce),
|
||||
maxLen: vtpm.Nonce,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too long",
|
||||
nonce: make([]byte, quoteprovider.Nonce+1),
|
||||
maxLen: quoteprovider.Nonce,
|
||||
shouldError: true,
|
||||
expectedErr: ErrTEENonceLength,
|
||||
},
|
||||
{
|
||||
name: "vTPM nonce too long",
|
||||
nonce: make([]byte, vtpm.Nonce+1),
|
||||
maxLen: vtpm.Nonce,
|
||||
shouldError: true,
|
||||
expectedErr: ErrVTPMNonceLength,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.maxLen == quoteprovider.Nonce {
|
||||
var target [quoteprovider.Nonce]byte
|
||||
err := validateNonce(tt.nonce, tt.maxLen, &target)
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
} else {
|
||||
var target [vtpm.Nonce]byte
|
||||
err := validateNonce(tt.nonce, tt.maxLen, &target)
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeAlgoRequest(t *testing.T) {
|
||||
req := &agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}
|
||||
decoded, err := decodeAlgoRequest(context.Background(), req)
|
||||
@@ -219,11 +388,38 @@ func TestEncodeResultResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequest(t *testing.T) {
|
||||
nonce := [quoteprovider.Nonce]byte{}
|
||||
req := &agent.AttestationRequest{TeeNonce: nonce[:]}
|
||||
teeNonce := make([]byte, quoteprovider.Nonce)
|
||||
vtpmNonce := make([]byte, vtpm.Nonce)
|
||||
|
||||
req := &agent.AttestationRequest{
|
||||
TeeNonce: teeNonce,
|
||||
VtpmNonce: vtpmNonce,
|
||||
Type: int32(attestation.SNP),
|
||||
}
|
||||
|
||||
decoded, err := decodeAttestationRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded)
|
||||
|
||||
decodedReq := decoded.(attestationReq)
|
||||
assert.Equal(t, attestation.SNP, decodedReq.AttType)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) {
|
||||
// Test with TEE nonce too long
|
||||
teeNonce := make([]byte, quoteprovider.Nonce+1)
|
||||
req := &agent.AttestationRequest{TeeNonce: teeNonce}
|
||||
|
||||
_, err := decodeAttestationRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrTEENonceLength, err)
|
||||
|
||||
// Test with vTPM nonce too long
|
||||
vtpmNonce := make([]byte, vtpm.Nonce+1)
|
||||
req = &agent.AttestationRequest{VtpmNonce: vtpmNonce}
|
||||
|
||||
_, err = decodeAttestationRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrVTPMNonceLength, err)
|
||||
}
|
||||
|
||||
func TestEncodeAttestationResponse(t *testing.T) {
|
||||
@@ -232,16 +428,76 @@ func TestEncodeAttestationResponse(t *testing.T) {
|
||||
assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationResultRequest(t *testing.T) {
|
||||
tokenNonce := make([]byte, vtpm.Nonce)
|
||||
req := &agent.AttestationResultRequest{
|
||||
TokenNonce: tokenNonce,
|
||||
Type: int32(attestation.SNP),
|
||||
}
|
||||
|
||||
decoded, err := decodeAttestationResultRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
decodedReq := decoded.(FetchAttestationResultReq)
|
||||
assert.Equal(t, attestation.SNP, decodedReq.AttType)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationResultRequestWithInvalidNonce(t *testing.T) {
|
||||
// Test with token nonce too long
|
||||
tokenNonce := make([]byte, vtpm.Nonce+1)
|
||||
req := &agent.AttestationResultRequest{TokenNonce: tokenNonce}
|
||||
|
||||
_, err := decodeAttestationResultRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrVTPMNonceLength, err)
|
||||
}
|
||||
|
||||
func TestEncodeAttestationResultResponse(t *testing.T) {
|
||||
encoded, err := encodeAttestationResultResponse(context.Background(), fetchAttestationResultRes{File: []byte("attestation")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.AttestationResultResponse{File: []byte("attestation")}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationResultRequest(t *testing.T) {
|
||||
nonce := [vtpm.Nonce]byte{}
|
||||
req := &agent.AttestationResultRequest{TokenNonce: nonce[:]}
|
||||
decoded, err := decodeAttestationResultRequest(context.Background(), req)
|
||||
func TestDecodeIMAMeasurementsRequest(t *testing.T) {
|
||||
decoded, err := decodeIMAMeasurementsRequest(context.Background(), &agent.IMAMeasurementsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, FetchAttestationResultReq{tokenNonce: nonce}, decoded)
|
||||
assert.Equal(t, imaMeasurementsReq{}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeIMAMeasurementsResponse(t *testing.T) {
|
||||
encoded, err := encodeIMAMeasurementsResponse(context.Background(), imaMeasurementsRes{
|
||||
File: []byte("ima"),
|
||||
PCR10: []byte("pcr10"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.IMAMeasurementsResponse{
|
||||
File: []byte("ima"),
|
||||
Pcr10: []byte("pcr10"),
|
||||
}, encoded)
|
||||
}
|
||||
|
||||
func TestAlgoWithStreamError(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, assert.AnError).Once()
|
||||
|
||||
err := server.Algo(mockStream)
|
||||
assert.Error(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDataWithStreamError(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, assert.AnError).Once()
|
||||
|
||||
err := server.Data(mockStream)
|
||||
assert.Error(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user