COCOS-256 - Progress bar on downloads (#290)

* add progress bar for downloads

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* better error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test and refactor

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix failing test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add test coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-29 14:33:00 +03:00
committed by GitHub
parent 765513b387
commit 6043ad150b
13 changed files with 436 additions and 83 deletions
+1 -1
View File
@@ -33,7 +33,7 @@ jobs:
- name: Set up protoc
run: |
PROTOC_VERSION=27.3
PROTOC_VERSION=28.1
PROTOC_GEN_VERSION=v1.34.2
PROTOC_GRPC_VERSION=v1.4.0
+1 -2
View File
@@ -1,6 +1,5 @@
build
build
cmd/manager/img
.cov
@@ -8,7 +7,7 @@ cmd/manager/img
*.pem
dist/
results.zip
*.zip
*.spec
*.tar
+5 -5
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.27.3
// protoc v5.28.1
// source: agent/agent.proto
package agent
@@ -413,7 +413,7 @@ var file_agent_agent_proto_rawDesc = []byte{
0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29,
0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfb, 0x01, 0x0a, 0x0c, 0x41, 0x67,
0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67,
0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c,
0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41,
@@ -425,12 +425,12 @@ var file_agent_agent_proto_rawDesc = []byte{
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73,
0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12,
0x46, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19,
0x48, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19,
0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65,
0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61,
0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
+1 -1
View File
@@ -11,7 +11,7 @@ service AgentService {
rpc Algo(stream AlgoRequest) returns (AlgoResponse) {}
rpc Data(stream DataRequest) returns (DataResponse) {}
rpc Result(ResultRequest) returns (stream ResultResponse) {}
rpc Attestation(AttestationRequest) returns (AttestationResponse) {}
rpc Attestation(AttestationRequest) returns (stream AttestationResponse) {}
}
message AlgoRequest {
+56 -30
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.4.0
// - protoc v5.27.3
// - protoc v5.28.1
// source: agent/agent.proto
package agent
@@ -35,7 +35,7 @@ type AgentServiceClient interface {
Algo(ctx context.Context, opts ...grpc.CallOption) (AgentService_AlgoClient, error)
Data(ctx context.Context, opts ...grpc.CallOption) (AgentService_DataClient, error)
Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (AgentService_ResultClient, error)
Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error)
Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (AgentService_AttestationClient, error)
}
type agentServiceClient struct {
@@ -149,14 +149,37 @@ func (x *agentServiceResultClient) Recv() (*ResultResponse, error) {
return m, nil
}
func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error) {
func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (AgentService_AttestationClient, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AttestationResponse)
err := c.cc.Invoke(ctx, AgentService_Attestation_FullMethodName, in, out, cOpts...)
stream, err := c.cc.NewStream(ctx, &AgentService_ServiceDesc.Streams[3], AgentService_Attestation_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
return out, nil
x := &agentServiceAttestationClient{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type AgentService_AttestationClient interface {
Recv() (*AttestationResponse, error)
grpc.ClientStream
}
type agentServiceAttestationClient struct {
grpc.ClientStream
}
func (x *agentServiceAttestationClient) Recv() (*AttestationResponse, error) {
m := new(AttestationResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// AgentServiceServer is the server API for AgentService service.
@@ -166,7 +189,7 @@ type AgentServiceServer interface {
Algo(AgentService_AlgoServer) error
Data(AgentService_DataServer) error
Result(*ResultRequest, AgentService_ResultServer) error
Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error)
Attestation(*AttestationRequest, AgentService_AttestationServer) error
mustEmbedUnimplementedAgentServiceServer()
}
@@ -183,8 +206,8 @@ func (UnimplementedAgentServiceServer) Data(AgentService_DataServer) error {
func (UnimplementedAgentServiceServer) Result(*ResultRequest, AgentService_ResultServer) error {
return status.Errorf(codes.Unimplemented, "method Result not implemented")
}
func (UnimplementedAgentServiceServer) Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Attestation not implemented")
func (UnimplementedAgentServiceServer) Attestation(*AttestationRequest, AgentService_AttestationServer) error {
return status.Errorf(codes.Unimplemented, "method Attestation not implemented")
}
func (UnimplementedAgentServiceServer) mustEmbedUnimplementedAgentServiceServer() {}
@@ -272,22 +295,25 @@ func (x *agentServiceResultServer) Send(m *ResultResponse) error {
return x.ServerStream.SendMsg(m)
}
func _AgentService_Attestation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AttestationRequest)
if err := dec(in); err != nil {
return nil, err
func _AgentService_Attestation_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(AttestationRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
if interceptor == nil {
return srv.(AgentServiceServer).Attestation(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AgentService_Attestation_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AgentServiceServer).Attestation(ctx, req.(*AttestationRequest))
}
return interceptor(ctx, in, info, handler)
return srv.(AgentServiceServer).Attestation(m, &agentServiceAttestationServer{ServerStream: stream})
}
type AgentService_AttestationServer interface {
Send(*AttestationResponse) error
grpc.ServerStream
}
type agentServiceAttestationServer struct {
grpc.ServerStream
}
func (x *agentServiceAttestationServer) Send(m *AttestationResponse) error {
return x.ServerStream.SendMsg(m)
}
// AgentService_ServiceDesc is the grpc.ServiceDesc for AgentService service.
@@ -296,12 +322,7 @@ func _AgentService_Attestation_Handler(srv interface{}, ctx context.Context, dec
var AgentService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "agent.AgentService",
HandlerType: (*AgentServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Attestation",
Handler: _AgentService_Attestation_Handler,
},
},
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "Algo",
@@ -318,6 +339,11 @@ var AgentService_ServiceDesc = grpc.ServiceDesc{
Handler: _AgentService_Result_Handler,
ServerStreams: true,
},
{
StreamName: "Attestation",
Handler: _AgentService_Attestation_Handler,
ServerStreams: true,
},
},
Metadata: "agent/agent.proto",
}
+39 -7
View File
@@ -6,15 +6,20 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"github.com/go-kit/kit/transport/grpc"
"github.com/ultravioletrs/cocos/agent"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const bufferSize = 1024 * 1024
const (
bufferSize = 1024 * 1024
FileSizeKey = "file-size"
)
var _ agent.AgentServiceServer = (*grpcServer)(nil)
@@ -156,12 +161,16 @@ func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_
}
rr := res.(*agent.ResultResponse)
reusltBuffer := bytes.NewBuffer(rr.File)
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
return status.Error(codes.Internal, err.Error())
}
resultBuffer := bytes.NewBuffer(rr.File)
buf := make([]byte, bufferSize)
for {
n, err := reusltBuffer.Read(buf)
n, err := resultBuffer.Read(buf)
if err == io.EOF {
break
}
@@ -177,11 +186,34 @@ func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_
return nil
}
func (s *grpcServer) Attestation(ctx context.Context, req *agent.AttestationRequest) (*agent.AttestationResponse, error) {
_, res, err := s.attestation.ServeGRPC(ctx, req)
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
_, res, err := s.attestation.ServeGRPC(stream.Context(), req)
if err != nil {
return nil, err
return err
}
rr := res.(*agent.AttestationResponse)
return rr, nil
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
return status.Error(codes.Internal, err.Error())
}
attestationBuffer := bytes.NewBuffer(rr.File)
buf := make([]byte, bufferSize)
for {
n, err := attestationBuffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return status.Error(codes.Internal, err.Error())
}
if err := stream.Send(&agent.AttestationResponse{File: buf[:n]}); err != nil {
return status.Error(codes.Internal, err.Error())
}
}
return nil
}
+28 -2
View File
@@ -12,6 +12,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
type MockAgentService_AlgoServer struct {
@@ -64,11 +65,34 @@ func (m *MockAgentService_ResultServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_ResultServer) SetHeader(metadata.MD) error {
return nil
}
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
args := m.Called(resp)
return args.Error(0)
}
type MockAgentService_AttestationServer struct {
grpc.ServerStream
mock.Mock
ctx context.Context
}
func (m *MockAgentService_AttestationServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_AttestationServer) Send(resp *agent.AttestationResponse) error {
args := m.Called(resp)
return args.Error(0)
}
func (m *MockAgentService_AttestationServer) SetHeader(metadata.MD) error {
return nil
}
func TestAlgo(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
@@ -124,12 +148,14 @@ func TestAttestation(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
reportData := [agent.ReportDataSize]byte{}
mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil)
resp, err := server.Attestation(context.Background(), &agent.AttestationRequest{ReportData: reportData[:]})
err := server.Attestation(&agent.AttestationRequest{ReportData: reportData[:]}, mockStream)
assert.NoError(t, err)
assert.Equal(t, []byte("attestation data"), resp.File)
mockService.AssertExpectations(t)
}
+1 -1
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.27.3
// protoc v5.28.1
// source: manager/manager.proto
package manager
+1 -1
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.4.0
// - protoc v5.27.3
// - protoc v5.28.1
// source: manager/manager.proto
package manager
+192
View File
@@ -4,12 +4,15 @@ package progressbar
import (
"bytes"
"errors"
"io"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
)
func TestRenderProgressBarWithMockedWidth(t *testing.T) {
@@ -130,3 +133,192 @@ func TestUpdateProgress(t *testing.T) {
assert.Equal(t, 75, pb.currentUploadedBytes)
assert.Equal(t, 75, pb.currentUploadPercentage)
}
type MockResultStream struct {
mock.Mock
agent.AgentService_ResultClient
}
func (m *MockResultStream) Recv() (*agent.ResultResponse, error) {
args := m.Called()
if res := args.Get(0); res != nil {
return res.(*agent.ResultResponse), args.Error(1)
}
return nil, args.Error(1)
}
type MockAttestationStream struct {
mock.Mock
agent.AgentService_AttestationClient
}
func (m *MockAttestationStream) Recv() (*agent.AttestationResponse, error) {
args := m.Called()
if res := args.Get(0); res != nil {
return res.(*agent.AttestationResponse), args.Error(1)
}
return nil, args.Error(1)
}
func TestReceiveResult(t *testing.T) {
tests := []struct {
name string
description string
totalSize int
chunks [][]byte
setupMock func(*MockResultStream)
wantResult []byte
wantErr error
}{
{
name: "successful single chunk receive",
description: "Receiving result",
totalSize: 5,
chunks: [][]byte{[]byte("hello")},
setupMock: func(m *MockResultStream) {
m.On("Recv").Return(&agent.ResultResponse{File: []byte("hello")}, nil).Once()
m.On("Recv").Return(nil, io.EOF).Once()
},
wantResult: []byte("hello"),
wantErr: nil,
},
{
name: "successful multi-chunk receive",
description: "Receiving result",
totalSize: 10,
chunks: [][]byte{[]byte("hello"), []byte("world")},
setupMock: func(m *MockResultStream) {
m.On("Recv").Return(&agent.ResultResponse{File: []byte("hello")}, nil).Once()
m.On("Recv").Return(&agent.ResultResponse{File: []byte("world")}, nil).Once()
m.On("Recv").Return(nil, io.EOF).Once()
},
wantResult: []byte("helloworld"),
wantErr: nil,
},
{
name: "stream error",
description: "Receiving result",
totalSize: 5,
setupMock: func(m *MockResultStream) {
m.On("Recv").Return(nil, errors.New("stream error")).Once()
},
wantResult: nil,
wantErr: errors.New("stream error"),
},
{
name: "empty result",
description: "Receiving result",
totalSize: 0,
setupMock: func(m *MockResultStream) {
m.On("Recv").Return(nil, io.EOF).Once()
},
wantResult: nil,
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStream := &MockResultStream{}
tt.setupMock(mockStream)
p := New(true)
// Disable terminal width check for tests
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream)
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
}
mockStream.AssertExpectations(t)
})
}
}
func TestReceiveAttestation(t *testing.T) {
tests := []struct {
name string
description string
totalSize int
chunks [][]byte
setupMock func(*MockAttestationStream)
wantResult []byte
wantErr error
}{
{
name: "successful single chunk receive",
description: "Receiving attestation",
totalSize: 5,
chunks: [][]byte{[]byte("proof")},
setupMock: func(m *MockAttestationStream) {
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("proof")}, nil).Once()
m.On("Recv").Return(nil, io.EOF).Once()
},
wantResult: []byte("proof"),
wantErr: nil,
},
{
name: "successful multi-chunk receive",
description: "Receiving attestation",
totalSize: 15,
chunks: [][]byte{[]byte("proof"), []byte("signature")},
setupMock: func(m *MockAttestationStream) {
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("proof")}, nil).Once()
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("signature")}, nil).Once()
m.On("Recv").Return(nil, io.EOF).Once()
},
wantResult: []byte("proofsignature"),
wantErr: nil,
},
{
name: "stream error",
description: "Receiving attestation",
totalSize: 5,
setupMock: func(m *MockAttestationStream) {
m.On("Recv").Return(nil, errors.New("attestation error")).Once()
},
wantResult: nil,
wantErr: errors.New("attestation error"),
},
{
name: "size mismatch",
description: "Receiving attestation",
totalSize: 3,
chunks: [][]byte{[]byte("toolong")},
setupMock: func(m *MockAttestationStream) {
m.On("Recv").Return(&agent.AttestationResponse{File: []byte("toolong")}, nil).Once()
},
wantResult: nil,
wantErr: errors.New("progress update exceeds total bytes: attempted to add 7 bytes, but only 3 bytes remain"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStream := &MockAttestationStream{}
tt.setupMock(mockStream)
p := New(true)
// Disable terminal width check for tests
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream)
if tt.wantErr != nil {
assert.Error(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
}
mockStream.AssertExpectations(t)
})
}
}
+61 -2
View File
@@ -72,11 +72,13 @@ type ProgressBar struct {
description string
maxWidth int
TerminalWidthFunc func() (int, error)
isDownload bool
}
func New() *ProgressBar {
func New(isDownload bool) *ProgressBar {
return &ProgressBar{
TerminalWidthFunc: terminalWidth,
isDownload: isDownload,
}
}
@@ -218,11 +220,14 @@ func (p *ProgressBar) renderProgressBar() error {
return fmt.Errorf("failed to clear progress bar: %v", err)
}
// Emoji to indicate progress action (📥 for datasets).
// Choose emoji based on operation type and content
emoji := "🚀 "
if strings.Contains(p.description, "data") {
emoji = "📦 "
} else if p.isDownload {
emoji = "📥 "
}
if _, err := builder.WriteString(color.New(color.FgYellow).Sprint(emoji)); err != nil {
return fmt.Errorf("failed to add emoji: %v", err)
}
@@ -297,3 +302,57 @@ func (p *ProgressBar) clearProgressBar() error {
return nil
}
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) {
return p.receiveStream(description, totalSize, func() ([]byte, error) {
response, err := stream.Recv()
if err != nil {
return nil, err
}
return response.File, nil
})
}
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) {
return p.receiveStream(description, totalSize, func() ([]byte, error) {
response, err := stream.Recv()
if err != nil {
return nil, err
}
return response.File, nil
})
}
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) {
p.reset(description, totalSize)
p.isDownload = true
var result []byte
for {
chunk, err := recv()
if err == io.EOF {
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
return nil, err
}
break
}
if err != nil {
return nil, err
}
chunkSize := len(chunk)
if err = p.updateProgress(chunkSize); err != nil {
return nil, err
}
result = append(result, chunk...)
if err := p.renderProgressBar(); err != nil {
return nil, err
}
}
return result, nil
}
+49 -30
View File
@@ -12,10 +12,11 @@ import (
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"errors"
"io"
"strconv"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/pkg/progressbar"
"google.golang.org/grpc/metadata"
@@ -30,9 +31,11 @@ type SDK interface {
}
const (
size64 = 64
algoProgressBarDescription = "Uploading algorithm"
dataProgressBarDescription = "Uploading data"
size64 = 64
algoProgressBarDescription = "Uploading algorithm"
dataProgressBarDescription = "Uploading data"
resultProgressDescription = "Downloading result"
attestationProgressDescription = "Downloading attestation"
)
type agentSDK struct {
@@ -62,12 +65,8 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
reqBuffer := bytes.NewBuffer(algorithm.Requirements)
pb := progressbar.New()
if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream); err != nil {
return err
}
return nil
pb := progressbar.New(false)
return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream)
}
func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error {
@@ -86,12 +85,8 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
}
dataBuffer := bytes.NewBuffer(dataset.Dataset)
pb := progressbar.New()
if err := pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream); err != nil {
return err
}
return nil
pb := progressbar.New(false)
return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream)
}
func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
@@ -108,19 +103,25 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
return nil, err
}
var result []byte
for {
response, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
result = append(result, response.File...)
incomingmd, err := stream.Header()
if err != nil {
return nil, err
}
return result, nil
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
if len(fileSizeStr) == 0 {
fileSizeStr = append(fileSizeStr, "0")
}
fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil {
return nil, err
}
pb := progressbar.New(true)
return pb.ReceiveResult(resultProgressDescription, fileSize, stream)
}
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) {
@@ -128,12 +129,30 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) (
ReportData: reportData[:],
}
response, err := sdk.client.Attestation(ctx, request)
stream, err := sdk.client.Attestation(ctx, request)
if err != nil {
return nil, err
}
return response.File, nil
incomingmd, err := stream.Header()
if err != nil {
return nil, err
}
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
if len(fileSizeStr) == 0 {
fileSizeStr = append(fileSizeStr, "0")
}
fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil {
return nil, err
}
pb := progressbar.New(true)
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream)
}
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
+1 -1
View File
@@ -11,10 +11,10 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"errors"
"os"
"testing"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"