COCOS-151 - Add compression/decompression option for CLI/Agent (#200)

* on the fly compression

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* rename file-hash to checksum

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* check error properly

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix connection handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-08-21 12:54:52 +03:00
committed by GitHub
parent e4ef1aae36
commit 899bfb0ec5
12 changed files with 287 additions and 132 deletions
-59
View File
@@ -1,59 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
import (
"archive/zip"
"bytes"
"fmt"
"io"
"os"
"path/filepath"
)
// ZipDirectory zips a directory and returns the zipped bytes.
func ZipDirectory() ([]byte, error) {
buf := new(bytes.Buffer)
zipWriter := zip.NewWriter(buf)
err := filepath.Walk(ResultsDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return fmt.Errorf("error walking the path %q: %v", path, err)
}
if info.IsDir() {
return nil
}
relPath, err := filepath.Rel(ResultsDir, path)
if err != nil {
return fmt.Errorf("error getting relative path for %q: %v", path, err)
}
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("error opening file %q: %v", path, err)
}
defer file.Close()
zipFile, err := zipWriter.Create(relPath)
if err != nil {
return fmt.Errorf("error creating zip file for %q: %v", path, err)
}
if _, err = io.Copy(zipFile, file); err != nil {
return fmt.Errorf("error copying file %q to zip: %v", path, err)
}
return err
})
if err != nil {
return nil, err
}
if err = zipWriter.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
+2 -1
View File
@@ -7,6 +7,7 @@ import (
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/internal"
)
func TestZipDirectory(t *testing.T) {
@@ -73,7 +74,7 @@ func TestZipDirectory(t *testing.T) {
}
}
if _, err := algorithm.ZipDirectory(); err != nil {
if _, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir); err != nil {
t.Errorf("ZipDirectory() error = %v", err)
}
})
+17
View File
@@ -6,6 +6,8 @@ import (
"context"
"encoding/json"
"fmt"
"google.golang.org/grpc/metadata"
)
var _ fmt.Stringer = (*Datasets)(nil)
@@ -69,3 +71,18 @@ func IndexFromContext(ctx context.Context) (int, bool) {
index, ok := ctx.Value(ManifestIndexKey{}).(int)
return index, ok
}
const DecompressKey = "decompress"
func DecompressFromContext(ctx context.Context) bool {
vals := metadata.ValueFromIncomingContext(ctx, DecompressKey)
if len(vals) == 0 {
return false
}
return vals[0] == "true"
}
func DecompressToContext(ctx context.Context, decompress bool) context.Context {
return metadata.AppendToOutgoingContext(ctx, DecompressKey, fmt.Sprintf("%t", decompress))
}
+17 -11
View File
@@ -18,6 +18,7 @@ import (
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/internal"
"golang.org/x/crypto/sha3"
)
@@ -191,16 +192,22 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, dataset.Filename))
if err != nil {
return fmt.Errorf("error creating dataset file: %v", err)
}
if DecompressFromContext(ctx) {
if err := internal.UnzipFromMemory(dataset.Dataset, algorithm.DatasetsDir); err != nil {
return fmt.Errorf("error decompressing dataset: %v", err)
}
} else {
f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, dataset.Filename))
if err != nil {
return fmt.Errorf("error creating dataset file: %v", err)
}
if _, err := f.Write(dataset.Dataset); err != nil {
return fmt.Errorf("error writing dataset to file: %v", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("error closing file: %v", err)
if _, err := f.Write(dataset.Dataset); err != nil {
return fmt.Errorf("error writing dataset to file: %v", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("error closing file: %v", err)
}
}
matched = true
@@ -212,7 +219,6 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
return ErrUndeclaredDataset
}
// Check if all datasets have been received
if len(as.computation.Datasets) == 0 {
as.sm.SendEvent(dataReceived)
}
@@ -288,7 +294,7 @@ func (as *agentService) runComputation() {
return
}
results, err := algorithm.ZipDirectory()
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
if err != nil {
as.runError = err
as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
+26 -6
View File
@@ -16,14 +16,14 @@ make cli
Retrieves attestation information from the SEV guest and saves it to a file.
To retrieve attestation from agent, use the following command:
```bash
./build/cocos-cli agent attestation get '<report_data>'
./build/cocos-cli attestation get '<report_data>'
```
#### Validate attestation
Validates the retrieved attestation information against a specified policy and checks its authenticity.
To validate and verify attestation from agent, use the following command:
```bash
./build/cocos-cli agent attestation validate '<attestation>' --report_data '<report_data>'
./build/cocos-cli attestation validate '<attestation>' --report_data '<report_data>'
```
##### Flags
- --config: Path to a JSON file containing the validation configuration. This can be used to override individual flags.
@@ -62,21 +62,41 @@ To validate and verify attestation from agent, use the following command:
To upload an algorithm, use the following command:
```bash
./build/cocos-cli agent algo /path/to/algorithm <private_key_file_path>
./build/cocos-cli algo /path/to/algorithm <private_key_file_path>
```
##### Flags
- -a, --algorithm string Algorithm type to run (default "bin")
- --python-runtime string Python runtime to use (default "python3")
- -r, --requirements string Python requirements file
#### Upload Dataset
To upload a dataset, use the following command:
```bash
./build/cocos-cli agent data /path/to/dataset.csv <private_key_file_path>
./build/cocos-cli data /path/to/dataset.csv <private_key_file_path>
```
Users can also upload directories which will be compressed on transit. Once received by agent they will be stored as compressed files or decompressed if the user passed the decompression argument.
##### Flags
- -d, --decompress Decompress the dataset on agent
#### Retrieve result
To retrieve the computation result, use the following command:
```bash
./build/cocos-cli agent result <private_key_file_path>
```
./build/cocos-cli result <private_key_file_path>
```
#### Checksum
When defining the manifest dataset and algorithm checksums are required. This can be done as below:
```bash
./build/cocos-cli checksum <path_to_dataset_or_algorithm>
```
+6 -12
View File
@@ -3,32 +3,26 @@
package cli
import (
"encoding/hex"
"log"
"os"
"github.com/spf13/cobra"
"golang.org/x/crypto/sha3"
"github.com/ultravioletrs/cocos/internal"
)
func (cli *CLI) NewFileHashCmd() *cobra.Command {
return &cobra.Command{
Use: "file-hash",
Use: "checksum",
Short: "Compute the sha3-256 hash of a file",
Example: "file-hash <file>",
Example: "checksum <file>",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
fileName := args[0]
path := args[0]
file, err := os.ReadFile(fileName)
hash, err := internal.ChecksumHex(path)
if err != nil {
log.Fatalf("Error reading dataset file: %v", err)
log.Fatalf("Error computing hash: %v", err)
}
hashBytes := sha3.Sum256(file)
hash := hex.EncodeToString(hashBytes[:])
log.Println("Hash of file:", hash)
},
}
+33 -6
View File
@@ -3,6 +3,7 @@
package cli
import (
"context"
"crypto/x509"
"encoding/pem"
"log"
@@ -11,27 +12,45 @@ import (
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/internal"
"google.golang.org/grpc/metadata"
)
var decompressDataset bool
func (cli *CLI) NewDatasetsCmd() *cobra.Command {
return &cobra.Command{
cmd := &cobra.Command{
Use: "data",
Short: "Upload a dataset",
Example: "data <dataset_path> <private_key_file_path>",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
datasetFile := args[0]
datasetPath := args[0]
log.Println("Uploading dataset:", datasetFile)
log.Println("Uploading dataset:", datasetPath)
dataset, err := os.ReadFile(datasetFile)
f, err := os.Stat(datasetPath)
if err != nil {
log.Fatalf("Error reading dataset file: %v", err)
}
var dataset []byte
if f.IsDir() {
dataset, err = internal.ZipDirectoryToMemory(datasetPath)
if err != nil {
log.Fatalf("Error zipping dataset directory: %v", err)
}
} else {
dataset, err = os.ReadFile(datasetPath)
if err != nil {
log.Fatalf("Error reading dataset file: %v", err)
}
}
dataReq := agent.Dataset{
Dataset: dataset,
Filename: path.Base(datasetFile),
Filename: path.Base(datasetPath),
}
privKeyFile, err := os.ReadFile(args[1])
@@ -43,13 +62,17 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
privKey := decodeKey(pemBlock)
if err := cli.agentSDK.Data(cmd.Context(), dataReq, privKey); err != nil {
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil {
log.Fatalf("Error uploading dataset: %v", err)
}
log.Println("Successfully uploaded dataset")
},
}
cmd.Flags().BoolVarP(&decompressDataset, "decompress", "d", false, "Decompress the dataset on agent")
return cmd
}
func decodeKey(b *pem.Block) interface{} {
@@ -74,3 +97,7 @@ func decodeKey(b *pem.Block) interface{} {
return nil
}
}
func addDatasetMetadata(ctx context.Context) context.Context {
return agent.DecompressToContext(ctx, decompressDataset)
}
+37
View File
@@ -3,9 +3,12 @@
package internal
import (
"encoding/hex"
"io"
"os"
"path/filepath"
"golang.org/x/crypto/sha3"
)
// CopyFile copies a file from srcPath to dstPath.
@@ -46,3 +49,37 @@ func DeleteFilesInDir(dirPath string) error {
return nil
}
// Checksum calculates the SHA3-256 checksum of the file or directory at path.
func Checksum(path string) ([]byte, error) {
file, err := os.Stat(path)
if err != nil {
return nil, err
}
if file.IsDir() {
f, err := ZipDirectoryToMemory(path)
if err != nil {
return nil, err
}
sum := sha3.Sum256(f)
return sum[:], nil
} else {
f, err := os.ReadFile(path)
if err != nil {
return nil, err
}
sum := sha3.Sum256(f)
return sum[:], nil
}
}
// ChecksumHex calculates the SHA3-256 checksum of the file or directory at path and returns it as a hex-encoded string.
func ChecksumHex(path string) (string, error) {
sum, err := Checksum(path)
if err != nil {
return "", err
}
return hex.EncodeToString(sum), nil
}
+102
View File
@@ -0,0 +1,102 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"archive/zip"
"bytes"
"io"
"os"
"path/filepath"
)
func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
buf := new(bytes.Buffer)
zipWriter := zip.NewWriter(buf)
err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
}
zipHeader, err := zip.FileInfoHeader(info)
if err != nil {
return err
}
zipHeader.Name = relPath
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
return err
}
fileToZip, err := os.Open(path)
if err != nil {
return err
}
defer fileToZip.Close()
_, err = io.Copy(zipWriterEntry, fileToZip)
return err
})
if err != nil {
zipWriter.Close()
return nil, err
}
if err := zipWriter.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func UnzipFromMemory(zipData []byte, targetDir string) error {
reader := bytes.NewReader(zipData)
zipReader, err := zip.NewReader(reader, int64(len(zipData)))
if err != nil {
return err
}
for _, file := range zipReader.File {
filePath := filepath.Join(targetDir, file.Name)
if file.FileInfo().IsDir() {
if err := os.MkdirAll(filePath, os.ModePerm); err != nil {
return err
}
continue
}
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return err
}
srcFile, err := file.Open()
if err != nil {
return err
}
defer srcFile.Close()
dstFile, err := os.Create(filePath)
if err != nil {
return err
}
defer dstFile.Close()
if _, err := io.Copy(dstFile, srcFile); err != nil {
return err
}
}
return nil
}
+4 -1
View File
@@ -82,7 +82,10 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
return err
}
ctx = metadata.NewOutgoingContext(ctx, md)
for k, v := range md {
ctx = metadata.AppendToOutgoingContext(ctx, k, v[0])
}
stream, err := sdk.client.Data(ctx)
if err != nil {
sdk.logger.Error("Failed to call Data RPC")
+9 -10
View File
@@ -13,11 +13,11 @@ import (
mglog "github.com/absmach/magistrala/logger"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/crypto/sha3"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -44,11 +44,6 @@ type svc struct {
func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, auth credentials.AuthInfo) {
s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAdress))
algo, err := os.ReadFile(algoPath)
if err != nil {
s.logger.Error(fmt.Sprintf("failed to read algorithm file: %s", err))
return
}
pubKey, err := os.ReadFile(pubKeyFile)
if err != nil {
@@ -63,16 +58,20 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
return
}
data, err := os.ReadFile(dataPath)
dataHash, err := internal.Checksum(dataPath)
if err != nil {
s.logger.Error(fmt.Sprintf("failed to read data file: %s", err))
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
return
}
dataHash := sha3.Sum256(data)
datasets = append(datasets, &manager.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
}
algoHash := sha3.Sum256(algo)
algoHash, err := internal.Checksum(algoPath)
if err != nil {
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
return
}
reqChan <- &manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReq{
+34 -26
View File
@@ -10,15 +10,16 @@ import (
"encoding/pem"
"fmt"
"log"
"net"
"os"
"strconv"
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/qemu"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/proto"
)
@@ -35,21 +36,19 @@ func main() {
}
attestedTLS := attestedTLSParam
algo, err := os.ReadFile(algoPath)
if err != nil {
log.Fatalf(fmt.Sprintf("failed to read algorithm file: %s", err))
}
data, err := os.ReadFile(dataPath)
if err != nil {
log.Fatalf(fmt.Sprintf("failed to read data file: %s", err))
}
pubKey, err := os.ReadFile(pubKeyFile)
if err != nil {
log.Fatalf(fmt.Sprintf("failed to read public key file: %s", err))
}
pubPem, _ := pem.Decode(pubKey)
algoHash := sha3.Sum256(algo)
dataHash := sha3.Sum256(data)
algoHash, err := internal.Checksum(algoPath)
if err != nil {
log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err))
}
dataHash, err := internal.Checksum(dataPath)
if err != nil {
log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err))
}
l, err := vsock.Listen(manager.ManagerVsockPort, nil)
if err != nil {
@@ -57,8 +56,8 @@ func main() {
}
ac := agent.Computation{
ID: "123",
Datasets: agent.Datasets{agent.Dataset{Hash: dataHash, UserKey: pubPem.Bytes}},
Algorithm: agent.Algorithm{Hash: algoHash, UserKey: pubPem.Bytes},
Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
Algorithm: agent.Algorithm{Hash: [32]byte(algoHash), UserKey: pubPem.Bytes},
ResultConsumers: []agent.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
@@ -66,7 +65,9 @@ func main() {
AttestedTls: attestedTLS,
},
}
fmt.Println(SendAgentConfig(3, ac))
if err := SendAgentConfig(3, ac); err != nil {
log.Fatal(err)
}
for {
conn, err := l.Accept()
@@ -74,18 +75,7 @@ func main() {
log.Println(err)
continue
}
b := make([]byte, 1024)
n, err := conn.Read(b)
if err != nil {
log.Println(err)
continue
}
conn.Close()
var mes pkgmanager.ClientStreamMessage
if err := proto.Unmarshal(b[:n], &mes); err != nil {
log.Println(err)
}
fmt.Println(mes.String())
go handleConnections(conn)
}
}
@@ -109,3 +99,21 @@ func SendAgentConfig(cid uint32, ac agent.Computation) error {
}
return nil
}
func handleConnections(conn net.Conn) {
defer conn.Close()
for {
b := make([]byte, 1024)
n, err := conn.Read(b)
if err != nil {
log.Println(err)
return
}
var message pkgmanager.ClientStreamMessage
if err := proto.Unmarshal(b[:n], &message); err != nil {
log.Println(err)
return
}
fmt.Println(message.String())
}
}