mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
46b94204df
* improve file streaming Signed-off-by: Sammy Oina <sammyoina@gmail.com> * error check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * empty line Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * send buffer test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * stream data and attestation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * mocks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * value check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * more value checks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * all files Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
383 lines
9.1 KiB
Go
383 lines
9.1 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package progressbar
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/fatih/color"
|
|
"github.com/ultravioletrs/cocos/agent"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
// Delimiters drawn around the progress bar, and the chunk size used when
// streaming files.
const (
	leftBracket  = "["
	rightBracket = "]"

	// bufferSize is the maximum number of bytes read from a file and sent
	// in a single stream message (1 MiB).
	bufferSize = 1 << 20
)
|
|
|
|
var (
|
|
_ streamSender = (*algoClientWrapper)(nil)
|
|
_ streamSender = (*dataClientWrapper)(nil)
|
|
warnOnlyOnce = false
|
|
)
|
|
|
|
// streamSender is the untyped view of an agent upload stream that the
// chunking helpers work against. The wrappers in this file adapt the
// concrete agent.AgentService_*Client streams to it.
type streamSender interface {
	// Send transmits one stream-specific request message.
	Send(any) error
	// CloseAndRecv finishes the upload and returns the stream's final
	// response (presumably the server's reply — see the wrapped client).
	CloseAndRecv() (any, error)
}
|
|
|
|
type algoClientWrapper struct {
|
|
client agent.AgentService_AlgoClient
|
|
}
|
|
|
|
func (a *algoClientWrapper) Send(req interface{}) error {
|
|
algoReq, ok := req.(*agent.AlgoRequest)
|
|
if !ok {
|
|
return fmt.Errorf("expected *AlgoRequest, got %T", req)
|
|
}
|
|
|
|
return a.client.Send(algoReq)
|
|
}
|
|
|
|
func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
|
|
return a.client.CloseAndRecv()
|
|
}
|
|
|
|
type dataClientWrapper struct {
|
|
client agent.AgentService_DataClient
|
|
}
|
|
|
|
func (a *dataClientWrapper) Send(req interface{}) error {
|
|
dataReq, ok := req.(*agent.DataRequest)
|
|
if !ok {
|
|
return fmt.Errorf("expected *DataRequest, got %T", req)
|
|
}
|
|
|
|
return a.client.Send(dataReq)
|
|
}
|
|
|
|
func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
|
|
return a.client.CloseAndRecv()
|
|
}
|
|
|
|
// ProgressBar draws a single-line, terminal-width progress bar on stdout while
// files are streamed to or from the agent. Use New to obtain one with
// TerminalWidthFunc set; rendering calls that function.
type ProgressBar struct {
	// Byte counters for the current transfer: the expected total (set by
	// reset), the bytes transferred so far, and the derived percentage.
	numberOfBytes, currentUploadedBytes, currentUploadPercentage int

	// description is the label printed in front of the bar.
	description string
	// maxWidth is the widest terminal width observed so far; clearing the
	// line blanks this many columns.
	maxWidth int

	// TerminalWidthFunc reports the current terminal width. It is exported
	// so callers can substitute their own measurement.
	TerminalWidthFunc func() (int, error)
	// isDownload selects the download emoji when rendering (unless the
	// description mentions "data").
	isDownload bool
}
|
|
|
|
func New(isDownload bool) *ProgressBar {
|
|
return &ProgressBar{
|
|
TerminalWidthFunc: terminalWidth,
|
|
isDownload: isDownload,
|
|
}
|
|
}
|
|
|
|
func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error {
|
|
algoFileInfo, err := algo.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
reqSize := 0
|
|
if req != nil {
|
|
reqFileInfo, err := req.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
reqSize = int(reqFileInfo.Size())
|
|
}
|
|
|
|
totalSize := int(algoFileInfo.Size()) + reqSize
|
|
p.reset(description, totalSize)
|
|
|
|
wrapper := &algoClientWrapper{client: stream}
|
|
|
|
// Send req first
|
|
if req != nil {
|
|
if err := p.sendBuffer(req, wrapper, func(data []byte) interface{} {
|
|
return &agent.AlgoRequest{Requirements: data}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Then send algo
|
|
if err := p.sendBuffer(algo, wrapper, func(data []byte) interface{} {
|
|
return &agent.AlgoRequest{Algorithm: data}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = wrapper.CloseAndRecv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error {
|
|
return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
|
|
return &agent.DataRequest{Dataset: data, Filename: filename}
|
|
})
|
|
}
|
|
|
|
func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
|
dataInfo, err := file.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.reset(description, int(dataInfo.Size()))
|
|
|
|
buf := make([]byte, bufferSize)
|
|
|
|
for {
|
|
n, err := file.Read(buf)
|
|
if err == io.EOF {
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = p.updateProgress(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := stream.Send(createRequest(buf[:n])); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
_, err = stream.CloseAndRecv()
|
|
return err
|
|
}
|
|
|
|
func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
|
buf := make([]byte, bufferSize)
|
|
|
|
for {
|
|
n, err := file.Read(buf)
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = p.updateProgress(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := stream.Send(createRequest(buf[:n])); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) reset(description string, totalBytes int) {
|
|
p.currentUploadedBytes = 0
|
|
p.currentUploadPercentage = 0
|
|
p.numberOfBytes = totalBytes
|
|
p.description = description
|
|
}
|
|
|
|
func (p *ProgressBar) updateProgress(bytesRead int) error {
|
|
if p.currentUploadedBytes+bytesRead > p.numberOfBytes {
|
|
return fmt.Errorf("progress update exceeds total bytes: attempted to add %d bytes, but only %d bytes remain", bytesRead, p.numberOfBytes-p.currentUploadedBytes)
|
|
}
|
|
|
|
p.currentUploadedBytes += bytesRead
|
|
p.currentUploadPercentage = p.currentUploadedBytes * 100 / p.numberOfBytes
|
|
|
|
return nil
|
|
}
|
|
|
|
// Progress bar example: 📦 Uploading algorithm... [█████░░░░░░░░░░░░] [25%].
|
|
func (p *ProgressBar) renderProgressBar() error {
|
|
var builder strings.Builder
|
|
|
|
// Get terminal width.
|
|
width, err := p.TerminalWidthFunc()
|
|
if err != nil {
|
|
if !warnOnlyOnce {
|
|
color.Red("Progress bar could not be rendered")
|
|
warnOnlyOnce = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if p.maxWidth < width {
|
|
p.maxWidth = width
|
|
}
|
|
|
|
if err := p.clearProgressBar(); err != nil {
|
|
return fmt.Errorf("failed to clear progress bar: %v", err)
|
|
}
|
|
|
|
// Choose emoji based on operation type and content
|
|
emoji := "🚀 "
|
|
if strings.Contains(p.description, "data") {
|
|
emoji = "📦 "
|
|
} else if p.isDownload {
|
|
emoji = "📥 "
|
|
}
|
|
|
|
if _, err := builder.WriteString(color.New(color.FgYellow).Sprint(emoji)); err != nil {
|
|
return fmt.Errorf("failed to add emoji: %v", err)
|
|
}
|
|
|
|
// The progress bar starts with the description.
|
|
description := color.New(color.FgYellow).Sprintf("%s ", p.description)
|
|
if _, err := builder.WriteString(description); err != nil {
|
|
return fmt.Errorf("failed to add description: %v", err)
|
|
}
|
|
|
|
// Add left bracket (colored).
|
|
leftBracket := color.New(color.FgBlue).Sprint(leftBracket)
|
|
if _, err := builder.WriteString(leftBracket); err != nil {
|
|
return fmt.Errorf("failed to add left bracket: %v", err)
|
|
}
|
|
|
|
// Calculate the progress bar's width.
|
|
progressWidth := width - builder.Len() - len(rightBracket+" [100%]")
|
|
numOfCharactersBody := progressWidth * p.currentUploadPercentage / 100
|
|
if numOfCharactersBody == 0 {
|
|
numOfCharactersBody = 1
|
|
}
|
|
|
|
numOfCharactersPadding := progressWidth - numOfCharactersBody
|
|
|
|
// Using unicode block characters for a smooth bar.
|
|
progress := color.New(color.FgGreen).Sprint(strings.Repeat("█", numOfCharactersBody))
|
|
if _, err := builder.WriteString(progress); err != nil {
|
|
return fmt.Errorf("failed to add progress strings: %v", err)
|
|
}
|
|
|
|
// Add the unfilled part (light blocks as padding).
|
|
padding := strings.Repeat("░", numOfCharactersPadding)
|
|
if _, err := builder.WriteString(padding); err != nil {
|
|
return fmt.Errorf("failed to add padding: %v", err)
|
|
}
|
|
|
|
// Add right bracket to progress bar.
|
|
rightBracket := color.New(color.FgBlue).Sprint("]")
|
|
if _, err := builder.WriteString(rightBracket); err != nil {
|
|
return fmt.Errorf("failed to add right bracket: %v", err)
|
|
}
|
|
|
|
// Add the percentage at the end inside square brackets.
|
|
strCurrentUploadPercentage := color.New(color.FgGreen).Sprintf(" [%d%%]", p.currentUploadPercentage)
|
|
if _, err := builder.WriteString(strCurrentUploadPercentage); err != nil {
|
|
return fmt.Errorf("failed to add upload percentage: %v", err)
|
|
}
|
|
|
|
// Write progress bar to the console.
|
|
if _, err := io.WriteString(os.Stdout, builder.String()); err != nil {
|
|
return fmt.Errorf("failed to write string: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func terminalWidth() (int, error) {
|
|
width, _, err := term.GetSize(int(os.Stdout.Fd()))
|
|
if err == nil {
|
|
return width, nil
|
|
}
|
|
|
|
return 0, err
|
|
}
|
|
|
|
func (p *ProgressBar) clearProgressBar() error {
|
|
emptySpace := fmt.Sprintf("\r%s\r", strings.Repeat(" ", p.maxWidth))
|
|
if _, err := io.WriteString(os.Stdout, emptySpace); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error {
|
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response.File, nil
|
|
}, resultFile)
|
|
}
|
|
|
|
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error {
|
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response.File, nil
|
|
}, attestationFile)
|
|
}
|
|
|
|
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error {
|
|
p.reset(description, totalSize)
|
|
p.isDownload = true
|
|
|
|
for {
|
|
chunk, err := recv()
|
|
if err == io.EOF {
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
chunkSize := len(chunk)
|
|
if err = p.updateProgress(chunkSize); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := file.Write(chunk); err != nil {
|
|
return err
|
|
}
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|