mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-256 - Progress bar on downloads (#290)
* add progress bar for downloads Signed-off-by: Sammy Oina <sammyoina@gmail.com> * better error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test and refactor Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix failing test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
765513b387
commit
6043ad150b
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Set up protoc
|
||||
run: |
|
||||
PROTOC_VERSION=27.3
|
||||
PROTOC_VERSION=28.1
|
||||
PROTOC_GEN_VERSION=v1.34.2
|
||||
PROTOC_GRPC_VERSION=v1.4.0
|
||||
|
||||
|
||||
+1
-2
@@ -1,6 +1,5 @@
|
||||
build
|
||||
|
||||
build
|
||||
cmd/manager/img
|
||||
|
||||
.cov
|
||||
@@ -8,7 +7,7 @@ cmd/manager/img
|
||||
*.pem
|
||||
|
||||
dist/
|
||||
results.zip
|
||||
*.zip
|
||||
*.spec
|
||||
*.tar
|
||||
|
||||
|
||||
+5
-5
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.27.3
|
||||
// protoc v5.28.1
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
@@ -413,7 +413,7 @@ var file_agent_agent_proto_rawDesc = []byte{
|
||||
0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29,
|
||||
0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfb, 0x01, 0x0a, 0x0c, 0x41, 0x67,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c,
|
||||
0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41,
|
||||
@@ -425,12 +425,12 @@ var file_agent_agent_proto_rawDesc = []byte{
|
||||
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73,
|
||||
0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12,
|
||||
0x46, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19,
|
||||
0x48, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19,
|
||||
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e,
|
||||
0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65,
|
||||
0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61,
|
||||
0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ service AgentService {
|
||||
rpc Algo(stream AlgoRequest) returns (AlgoResponse) {}
|
||||
rpc Data(stream DataRequest) returns (DataResponse) {}
|
||||
rpc Result(ResultRequest) returns (stream ResultResponse) {}
|
||||
rpc Attestation(AttestationRequest) returns (AttestationResponse) {}
|
||||
rpc Attestation(AttestationRequest) returns (stream AttestationResponse) {}
|
||||
}
|
||||
|
||||
message AlgoRequest {
|
||||
|
||||
+56
-30
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.4.0
|
||||
// - protoc v5.27.3
|
||||
// - protoc v5.28.1
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
@@ -35,7 +35,7 @@ type AgentServiceClient interface {
|
||||
Algo(ctx context.Context, opts ...grpc.CallOption) (AgentService_AlgoClient, error)
|
||||
Data(ctx context.Context, opts ...grpc.CallOption) (AgentService_DataClient, error)
|
||||
Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (AgentService_ResultClient, error)
|
||||
Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error)
|
||||
Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (AgentService_AttestationClient, error)
|
||||
}
|
||||
|
||||
type agentServiceClient struct {
|
||||
@@ -149,14 +149,37 @@ func (x *agentServiceResultClient) Recv() (*ResultResponse, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error) {
|
||||
func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (AgentService_AttestationClient, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(AttestationResponse)
|
||||
err := c.cc.Invoke(ctx, AgentService_Attestation_FullMethodName, in, out, cOpts...)
|
||||
stream, err := c.cc.NewStream(ctx, &AgentService_ServiceDesc.Streams[3], AgentService_Attestation_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
x := &agentServiceAttestationClient{ClientStream: stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type AgentService_AttestationClient interface {
|
||||
Recv() (*AttestationResponse, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type agentServiceAttestationClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *agentServiceAttestationClient) Recv() (*AttestationResponse, error) {
|
||||
m := new(AttestationResponse)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AgentServiceServer is the server API for AgentService service.
|
||||
@@ -166,7 +189,7 @@ type AgentServiceServer interface {
|
||||
Algo(AgentService_AlgoServer) error
|
||||
Data(AgentService_DataServer) error
|
||||
Result(*ResultRequest, AgentService_ResultServer) error
|
||||
Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error)
|
||||
Attestation(*AttestationRequest, AgentService_AttestationServer) error
|
||||
mustEmbedUnimplementedAgentServiceServer()
|
||||
}
|
||||
|
||||
@@ -183,8 +206,8 @@ func (UnimplementedAgentServiceServer) Data(AgentService_DataServer) error {
|
||||
func (UnimplementedAgentServiceServer) Result(*ResultRequest, AgentService_ResultServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Result not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Attestation not implemented")
|
||||
func (UnimplementedAgentServiceServer) Attestation(*AttestationRequest, AgentService_AttestationServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Attestation not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) mustEmbedUnimplementedAgentServiceServer() {}
|
||||
|
||||
@@ -272,22 +295,25 @@ func (x *agentServiceResultServer) Send(m *ResultResponse) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _AgentService_Attestation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AttestationRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
func _AgentService_Attestation_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(AttestationRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AgentServiceServer).Attestation(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: AgentService_Attestation_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AgentServiceServer).Attestation(ctx, req.(*AttestationRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
return srv.(AgentServiceServer).Attestation(m, &agentServiceAttestationServer{ServerStream: stream})
|
||||
}
|
||||
|
||||
type AgentService_AttestationServer interface {
|
||||
Send(*AttestationResponse) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type agentServiceAttestationServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *agentServiceAttestationServer) Send(m *AttestationResponse) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
// AgentService_ServiceDesc is the grpc.ServiceDesc for AgentService service.
|
||||
@@ -296,12 +322,7 @@ func _AgentService_Attestation_Handler(srv interface{}, ctx context.Context, dec
|
||||
var AgentService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "agent.AgentService",
|
||||
HandlerType: (*AgentServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Attestation",
|
||||
Handler: _AgentService_Attestation_Handler,
|
||||
},
|
||||
},
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Algo",
|
||||
@@ -318,6 +339,11 @@ var AgentService_ServiceDesc = grpc.ServiceDesc{
|
||||
Handler: _AgentService_Result_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "Attestation",
|
||||
Handler: _AgentService_Attestation_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "agent/agent.proto",
|
||||
}
|
||||
|
||||
@@ -6,15 +6,20 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/go-kit/kit/transport/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const bufferSize = 1024 * 1024
|
||||
const (
|
||||
bufferSize = 1024 * 1024
|
||||
FileSizeKey = "file-size"
|
||||
)
|
||||
|
||||
var _ agent.AgentServiceServer = (*grpcServer)(nil)
|
||||
|
||||
@@ -156,12 +161,16 @@ func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_
|
||||
}
|
||||
rr := res.(*agent.ResultResponse)
|
||||
|
||||
reusltBuffer := bytes.NewBuffer(rr.File)
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
resultBuffer := bytes.NewBuffer(rr.File)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := reusltBuffer.Read(buf)
|
||||
n, err := resultBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
@@ -177,11 +186,34 @@ func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) Attestation(ctx context.Context, req *agent.AttestationRequest) (*agent.AttestationResponse, error) {
|
||||
_, res, err := s.attestation.ServeGRPC(ctx, req)
|
||||
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
|
||||
_, res, err := s.attestation.ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.AttestationResponse)
|
||||
return rr, nil
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
attestationBuffer := bytes.NewBuffer(rr.File)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := attestationBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.AttestationResponse{File: buf[:n]}); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type MockAgentService_AlgoServer struct {
|
||||
@@ -64,11 +65,34 @@ func (m *MockAgentService_ResultServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAgentService_AttestationServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AttestationServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AttestationServer) Send(resp *agent.AttestationResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AttestationServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
@@ -124,12 +148,14 @@ func TestAttestation(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
|
||||
|
||||
reportData := [agent.ReportDataSize]byte{}
|
||||
mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil)
|
||||
|
||||
resp, err := server.Attestation(context.Background(), &agent.AttestationRequest{ReportData: reportData[:]})
|
||||
err := server.Attestation(&agent.AttestationRequest{ReportData: reportData[:]}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("attestation data"), resp.File)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.27.3
|
||||
// protoc v5.28.1
|
||||
// source: manager/manager.proto
|
||||
|
||||
package manager
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.4.0
|
||||
// - protoc v5.27.3
|
||||
// - protoc v5.28.1
|
||||
// source: manager/manager.proto
|
||||
|
||||
package manager
|
||||
|
||||
@@ -4,12 +4,15 @@ package progressbar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
)
|
||||
|
||||
func TestRenderProgressBarWithMockedWidth(t *testing.T) {
|
||||
@@ -130,3 +133,192 @@ func TestUpdateProgress(t *testing.T) {
|
||||
assert.Equal(t, 75, pb.currentUploadedBytes)
|
||||
assert.Equal(t, 75, pb.currentUploadPercentage)
|
||||
}
|
||||
|
||||
type MockResultStream struct {
|
||||
mock.Mock
|
||||
agent.AgentService_ResultClient
|
||||
}
|
||||
|
||||
func (m *MockResultStream) Recv() (*agent.ResultResponse, error) {
|
||||
args := m.Called()
|
||||
if res := args.Get(0); res != nil {
|
||||
return res.(*agent.ResultResponse), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
type MockAttestationStream struct {
|
||||
mock.Mock
|
||||
agent.AgentService_AttestationClient
|
||||
}
|
||||
|
||||
func (m *MockAttestationStream) Recv() (*agent.AttestationResponse, error) {
|
||||
args := m.Called()
|
||||
if res := args.Get(0); res != nil {
|
||||
return res.(*agent.AttestationResponse), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
func TestReceiveResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
totalSize int
|
||||
chunks [][]byte
|
||||
setupMock func(*MockResultStream)
|
||||
wantResult []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful single chunk receive",
|
||||
description: "Receiving result",
|
||||
totalSize: 5,
|
||||
chunks: [][]byte{[]byte("hello")},
|
||||
setupMock: func(m *MockResultStream) {
|
||||
m.On("Recv").Return(&agent.ResultResponse{File: []byte("hello")}, nil).Once()
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: []byte("hello"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "successful multi-chunk receive",
|
||||
description: "Receiving result",
|
||||
totalSize: 10,
|
||||
chunks: [][]byte{[]byte("hello"), []byte("world")},
|
||||
setupMock: func(m *MockResultStream) {
|
||||
m.On("Recv").Return(&agent.ResultResponse{File: []byte("hello")}, nil).Once()
|
||||
m.On("Recv").Return(&agent.ResultResponse{File: []byte("world")}, nil).Once()
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: []byte("helloworld"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "stream error",
|
||||
description: "Receiving result",
|
||||
totalSize: 5,
|
||||
setupMock: func(m *MockResultStream) {
|
||||
m.On("Recv").Return(nil, errors.New("stream error")).Once()
|
||||
},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("stream error"),
|
||||
},
|
||||
{
|
||||
name: "empty result",
|
||||
description: "Receiving result",
|
||||
totalSize: 0,
|
||||
setupMock: func(m *MockResultStream) {
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStream := &MockResultStream{}
|
||||
tt.setupMock(mockStream)
|
||||
|
||||
p := New(true)
|
||||
// Disable terminal width check for tests
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
}
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiveAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
totalSize int
|
||||
chunks [][]byte
|
||||
setupMock func(*MockAttestationStream)
|
||||
wantResult []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful single chunk receive",
|
||||
description: "Receiving attestation",
|
||||
totalSize: 5,
|
||||
chunks: [][]byte{[]byte("proof")},
|
||||
setupMock: func(m *MockAttestationStream) {
|
||||
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("proof")}, nil).Once()
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: []byte("proof"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "successful multi-chunk receive",
|
||||
description: "Receiving attestation",
|
||||
totalSize: 15,
|
||||
chunks: [][]byte{[]byte("proof"), []byte("signature")},
|
||||
setupMock: func(m *MockAttestationStream) {
|
||||
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("proof")}, nil).Once()
|
||||
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("signature")}, nil).Once()
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: []byte("proofsignature"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "stream error",
|
||||
description: "Receiving attestation",
|
||||
totalSize: 5,
|
||||
setupMock: func(m *MockAttestationStream) {
|
||||
m.On("Recv").Return(nil, errors.New("attestation error")).Once()
|
||||
},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("attestation error"),
|
||||
},
|
||||
{
|
||||
name: "size mismatch",
|
||||
description: "Receiving attestation",
|
||||
totalSize: 3,
|
||||
chunks: [][]byte{[]byte("toolong")},
|
||||
setupMock: func(m *MockAttestationStream) {
|
||||
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("toolong")}, nil).Once()
|
||||
},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("progress update exceeds total bytes: attempted to add 7 bytes, but only 3 bytes remain"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStream := &MockAttestationStream{}
|
||||
tt.setupMock(mockStream)
|
||||
|
||||
p := New(true)
|
||||
// Disable terminal width check for tests
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
}
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,11 +72,13 @@ type ProgressBar struct {
|
||||
description string
|
||||
maxWidth int
|
||||
TerminalWidthFunc func() (int, error)
|
||||
isDownload bool
|
||||
}
|
||||
|
||||
func New() *ProgressBar {
|
||||
func New(isDownload bool) *ProgressBar {
|
||||
return &ProgressBar{
|
||||
TerminalWidthFunc: terminalWidth,
|
||||
isDownload: isDownload,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,11 +220,14 @@ func (p *ProgressBar) renderProgressBar() error {
|
||||
return fmt.Errorf("failed to clear progress bar: %v", err)
|
||||
}
|
||||
|
||||
// Emoji to indicate progress action (📥 for datasets).
|
||||
// Choose emoji based on operation type and content
|
||||
emoji := "🚀 "
|
||||
if strings.Contains(p.description, "data") {
|
||||
emoji = "📦 "
|
||||
} else if p.isDownload {
|
||||
emoji = "📥 "
|
||||
}
|
||||
|
||||
if _, err := builder.WriteString(color.New(color.FgYellow).Sprint(emoji)); err != nil {
|
||||
return fmt.Errorf("failed to add emoji: %v", err)
|
||||
}
|
||||
@@ -297,3 +302,57 @@ func (p *ProgressBar) clearProgressBar() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) {
|
||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response.File, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) {
|
||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response.File, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) {
|
||||
p.reset(description, totalSize)
|
||||
p.isDownload = true
|
||||
|
||||
var result []byte
|
||||
for {
|
||||
chunk, err := recv()
|
||||
if err == io.EOF {
|
||||
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chunkSize := len(chunk)
|
||||
if err = p.updateProgress(chunkSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, chunk...)
|
||||
|
||||
if err := p.renderProgressBar(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
+49
-30
@@ -12,10 +12,11 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/pkg/progressbar"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -30,9 +31,11 @@ type SDK interface {
|
||||
}
|
||||
|
||||
const (
|
||||
size64 = 64
|
||||
algoProgressBarDescription = "Uploading algorithm"
|
||||
dataProgressBarDescription = "Uploading data"
|
||||
size64 = 64
|
||||
algoProgressBarDescription = "Uploading algorithm"
|
||||
dataProgressBarDescription = "Uploading data"
|
||||
resultProgressDescription = "Downloading result"
|
||||
attestationProgressDescription = "Downloading attestation"
|
||||
)
|
||||
|
||||
type agentSDK struct {
|
||||
@@ -62,12 +65,8 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
|
||||
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
|
||||
reqBuffer := bytes.NewBuffer(algorithm.Requirements)
|
||||
|
||||
pb := progressbar.New()
|
||||
if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
pb := progressbar.New(false)
|
||||
return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error {
|
||||
@@ -86,12 +85,8 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
|
||||
}
|
||||
dataBuffer := bytes.NewBuffer(dataset.Dataset)
|
||||
|
||||
pb := progressbar.New()
|
||||
if err := pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
pb := progressbar.New(false)
|
||||
return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
||||
@@ -108,19 +103,25 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result []byte
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, response.File...)
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveResult(resultProgressDescription, fileSize, stream)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) {
|
||||
@@ -128,12 +129,30 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) (
|
||||
ReportData: reportData[:],
|
||||
}
|
||||
|
||||
response, err := sdk.client.Attestation(ctx, request)
|
||||
stream, err := sdk.client.Attestation(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response.File, nil
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream)
|
||||
}
|
||||
|
||||
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
Reference in New Issue
Block a user