mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
5377dd4d7f
* Refactor mock interfaces to use 'any' instead of 'interface{}' for improved type safety and readability across multiple files in the manager and pkg directories.
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Update Go version to 1.25.x in CI workflows and remove obsolete Go package files
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Add mock implementations for various components in the attestation and SDK packages
- Created mock for MeasurementProvider in pkg/attestation/cmdconfig/mocks/mocks_test.go
- Created mock for Provider in pkg/attestation/mocks/mocks_test.go
- Created mock for Client in pkg/clients/grpc/mocks/mocks_test.go
- Created mock for SDK in pkg/sdk/mocks/mocks_test.go
These mocks are generated using mockery and are intended for unit testing purposes.
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Remove autogenerated mock files and update mock usage in tests
- Deleted mocks for gRPC clients in pkg/clients/grpc/mocks/mocks_test.go and pkg/sdk/mocks/mocks_test.go.
- Updated test files in pkg/progressbar/progress_test.go to use the new mock structure without type parameters for gRPC client interfaces.
- Refactored mock generation in pkg/sdk/mocks/sdk.go to streamline the mock creation process and ensure consistency across mock methods.
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Update protobuf generated files for events and manager
- Bump protoc-gen-go version from v1.36.5 to v1.36.8 in events.pb.go and manager.pb.go.
- Refactor raw descriptor definitions in events.pb.go and manager.pb.go to use string concatenation for better readability and maintainability.
- Ensure compatibility with the latest protobuf specifications and improve code generation consistency.
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Update test commands to use GOTOOLCHAIN for consistent Go version handling
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* Fix GOTOOLCHAIN usage in test command for consistency
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
---------
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
400 lines
9.5 KiB
Go
400 lines
9.5 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package progressbar
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/fatih/color"
|
|
"github.com/ultravioletrs/cocos/agent"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
// Rendering and transfer constants.
const (
	// leftBracket opens the bar section of the rendered progress line.
	leftBracket = "["
	// rightBracket closes the bar section of the rendered progress line.
	rightBracket = "]"
	// bufferSize is the size in bytes (1 MiB) of the buffer used when
	// reading a file chunk by chunk for streaming.
	bufferSize = 1 << 20
)
|
|
|
|
var (
|
|
_ streamSender = (*algoClientWrapper)(nil)
|
|
_ streamSender = (*dataClientWrapper)(nil)
|
|
warnOnlyOnce = false
|
|
)
|
|
|
|
// streamSender is the minimal stream surface the upload helpers need:
// sending request messages one at a time and then closing the stream to
// obtain the final response. Messages are passed as any so that one upload
// loop can serve differently typed client streams.
type streamSender interface {
	// CloseAndRecv closes the stream and returns the response, if any.
	CloseAndRecv() (any, error)
	// Send transmits one request message on the stream.
	Send(any) error
}
|
|
|
|
type algoClientWrapper struct {
|
|
client agent.AgentService_AlgoClient
|
|
}
|
|
|
|
func (a *algoClientWrapper) Send(req any) error {
|
|
algoReq, ok := req.(*agent.AlgoRequest)
|
|
if !ok {
|
|
return fmt.Errorf("expected *AlgoRequest, got %T", req)
|
|
}
|
|
|
|
return a.client.Send(algoReq)
|
|
}
|
|
|
|
func (a *algoClientWrapper) CloseAndRecv() (any, error) {
|
|
return a.client.CloseAndRecv()
|
|
}
|
|
|
|
type dataClientWrapper struct {
|
|
client agent.AgentService_DataClient
|
|
}
|
|
|
|
func (a *dataClientWrapper) Send(req any) error {
|
|
dataReq, ok := req.(*agent.DataRequest)
|
|
if !ok {
|
|
return fmt.Errorf("expected *DataRequest, got %T", req)
|
|
}
|
|
|
|
return a.client.Send(dataReq)
|
|
}
|
|
|
|
func (a *dataClientWrapper) CloseAndRecv() (any, error) {
|
|
return a.client.CloseAndRecv()
|
|
}
|
|
|
|
// ProgressBar tracks the progress of a single streamed transfer and renders
// it as a one-line bar on standard output.
type ProgressBar struct {
	// numberOfBytes is the total size of the current transfer, set by reset.
	numberOfBytes int
	// currentUploadedBytes is the number of bytes processed so far.
	currentUploadedBytes int
	// currentUploadPercentage is the integer percentage of the transfer done.
	currentUploadPercentage int
	// description is the label rendered in front of the bar.
	description string
	// maxWidth is the widest terminal width observed so far; it determines
	// how many blanks are written when the line is cleared.
	maxWidth int
	// TerminalWidthFunc reports the current terminal width. It is exported so
	// that callers can substitute their own implementation.
	TerminalWidthFunc func() (int, error)
	// isDownload selects the download emoji when rendering.
	isDownload bool
}
|
|
|
|
func New(isDownload bool) *ProgressBar {
|
|
return &ProgressBar{
|
|
TerminalWidthFunc: terminalWidth,
|
|
isDownload: isDownload,
|
|
}
|
|
}
|
|
|
|
func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error {
|
|
algoFileInfo, err := algo.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
reqSize := 0
|
|
if req != nil {
|
|
reqFileInfo, err := req.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
reqSize = int(reqFileInfo.Size())
|
|
}
|
|
|
|
totalSize := int(algoFileInfo.Size()) + reqSize
|
|
p.reset(description, totalSize)
|
|
|
|
wrapper := &algoClientWrapper{client: stream}
|
|
|
|
// Send req first
|
|
if req != nil {
|
|
if err := p.sendBuffer(req, wrapper, func(data []byte) any {
|
|
return &agent.AlgoRequest{Requirements: data}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Then send algo
|
|
if err := p.sendBuffer(algo, wrapper, func(data []byte) any {
|
|
return &agent.AlgoRequest{Algorithm: data}
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = wrapper.CloseAndRecv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error {
|
|
return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) any {
|
|
return &agent.DataRequest{Dataset: data, Filename: filename}
|
|
})
|
|
}
|
|
|
|
func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) any) error {
|
|
dataInfo, err := file.Stat()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.reset(description, int(dataInfo.Size()))
|
|
|
|
buf := make([]byte, bufferSize)
|
|
|
|
for {
|
|
n, err := file.Read(buf)
|
|
if err == io.EOF {
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = p.updateProgress(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := stream.Send(createRequest(buf[:n])); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
_, err = stream.CloseAndRecv()
|
|
return err
|
|
}
|
|
|
|
func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) any) error {
|
|
buf := make([]byte, bufferSize)
|
|
|
|
for {
|
|
n, err := file.Read(buf)
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = p.updateProgress(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := stream.Send(createRequest(buf[:n])); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) reset(description string, totalBytes int) {
|
|
p.currentUploadedBytes = 0
|
|
p.currentUploadPercentage = 0
|
|
p.numberOfBytes = totalBytes
|
|
p.description = description
|
|
}
|
|
|
|
func (p *ProgressBar) updateProgress(bytesRead int) error {
|
|
if p.currentUploadedBytes+bytesRead > p.numberOfBytes {
|
|
return fmt.Errorf("progress update exceeds total bytes: attempted to add %d bytes, but only %d bytes remain", bytesRead, p.numberOfBytes-p.currentUploadedBytes)
|
|
}
|
|
|
|
p.currentUploadedBytes += bytesRead
|
|
p.currentUploadPercentage = p.currentUploadedBytes * 100 / p.numberOfBytes
|
|
|
|
return nil
|
|
}
|
|
|
|
// Progress bar example: 📦 Uploading algorithm... [█████░░░░░░░░░░░░] [25%].
|
|
func (p *ProgressBar) renderProgressBar() error {
|
|
var builder strings.Builder
|
|
|
|
// Get terminal width.
|
|
width, err := p.TerminalWidthFunc()
|
|
if err != nil {
|
|
if !warnOnlyOnce {
|
|
color.Red("Progress bar could not be rendered")
|
|
warnOnlyOnce = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if p.maxWidth < width {
|
|
p.maxWidth = width
|
|
}
|
|
|
|
if err := p.clearProgressBar(); err != nil {
|
|
return fmt.Errorf("failed to clear progress bar: %v", err)
|
|
}
|
|
|
|
// Choose emoji based on operation type and content
|
|
emoji := "🚀 "
|
|
if strings.Contains(p.description, "data") {
|
|
emoji = "📦 "
|
|
} else if p.isDownload {
|
|
emoji = "📥 "
|
|
}
|
|
|
|
if _, err := builder.WriteString(color.New(color.FgYellow).Sprint(emoji)); err != nil {
|
|
return fmt.Errorf("failed to add emoji: %v", err)
|
|
}
|
|
|
|
// The progress bar starts with the description.
|
|
description := color.New(color.FgYellow).Sprintf("%s ", p.description)
|
|
if _, err := builder.WriteString(description); err != nil {
|
|
return fmt.Errorf("failed to add description: %v", err)
|
|
}
|
|
|
|
// Add left bracket (colored).
|
|
leftBracket := color.New(color.FgBlue).Sprint(leftBracket)
|
|
if _, err := builder.WriteString(leftBracket); err != nil {
|
|
return fmt.Errorf("failed to add left bracket: %v", err)
|
|
}
|
|
|
|
// Calculate the progress bar's width.
|
|
progressWidth := width - builder.Len() - len(rightBracket+" [100%]")
|
|
numOfCharactersBody := progressWidth * p.currentUploadPercentage / 100
|
|
if numOfCharactersBody == 0 {
|
|
numOfCharactersBody = 1
|
|
}
|
|
|
|
numOfCharactersPadding := progressWidth - numOfCharactersBody
|
|
|
|
// Using unicode block characters for a smooth bar.
|
|
progress := color.New(color.FgGreen).Sprint(strings.Repeat("█", numOfCharactersBody))
|
|
if _, err := builder.WriteString(progress); err != nil {
|
|
return fmt.Errorf("failed to add progress strings: %v", err)
|
|
}
|
|
|
|
// Add the unfilled part (light blocks as padding).
|
|
padding := strings.Repeat("░", numOfCharactersPadding)
|
|
if _, err := builder.WriteString(padding); err != nil {
|
|
return fmt.Errorf("failed to add padding: %v", err)
|
|
}
|
|
|
|
// Add right bracket to progress bar.
|
|
rightBracket := color.New(color.FgBlue).Sprint("]")
|
|
if _, err := builder.WriteString(rightBracket); err != nil {
|
|
return fmt.Errorf("failed to add right bracket: %v", err)
|
|
}
|
|
|
|
// Add the percentage at the end inside square brackets.
|
|
strCurrentUploadPercentage := color.New(color.FgGreen).Sprintf(" [%d%%]", p.currentUploadPercentage)
|
|
if _, err := builder.WriteString(strCurrentUploadPercentage); err != nil {
|
|
return fmt.Errorf("failed to add upload percentage: %v", err)
|
|
}
|
|
|
|
// Write progress bar to the console.
|
|
if _, err := io.WriteString(os.Stdout, builder.String()); err != nil {
|
|
return fmt.Errorf("failed to write string: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func terminalWidth() (int, error) {
|
|
width, _, err := term.GetSize(int(os.Stdout.Fd()))
|
|
if err == nil {
|
|
return width, nil
|
|
}
|
|
|
|
return 0, err
|
|
}
|
|
|
|
func (p *ProgressBar) clearProgressBar() error {
|
|
emptySpace := fmt.Sprintf("\r%s\r", strings.Repeat(" ", p.maxWidth))
|
|
if _, err := io.WriteString(os.Stdout, emptySpace); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error {
|
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response.File, nil
|
|
}, resultFile)
|
|
}
|
|
|
|
func (p *ProgressBar) ReceiveIMAMeasurements(description string, totalSize int, stream agent.AgentService_IMAMeasurementsClient, resultFile *os.File) ([]byte, error) {
|
|
pcr10 := make([]byte, vtpm.Hash1)
|
|
err := p.receiveStream(description, totalSize, func() ([]byte, error) {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
copy(pcr10, response.Pcr10[:20])
|
|
|
|
return response.File, nil
|
|
}, resultFile)
|
|
|
|
return pcr10, err
|
|
}
|
|
|
|
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error {
|
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response.File, nil
|
|
}, attestationFile)
|
|
}
|
|
|
|
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error {
|
|
p.reset(description, totalSize)
|
|
p.isDownload = true
|
|
|
|
for {
|
|
chunk, err := recv()
|
|
if err == io.EOF {
|
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
|
return err
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
chunkSize := len(chunk)
|
|
if err = p.updateProgress(chunkSize); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := file.Write(chunk); err != nil {
|
|
return err
|
|
}
|
|
if err := p.renderProgressBar(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|