mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
6169766666
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* Update attestationFromCert function to include ccPlatform parameter for enhanced attestation processing Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: migrate dependencies from supermq to magistrala and update build configurations Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: update project dependencies, repository source, and support TDX QuoteV5 attestation Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
379 lines
12 KiB
Go
379 lines
12 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"log/slog"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
mglog "github.com/absmach/magistrala/logger"
|
|
smqserver "github.com/absmach/magistrala/pkg/server"
|
|
grpcserver "github.com/absmach/magistrala/pkg/server/grpc"
|
|
"github.com/caarlos0/env/v11"
|
|
"github.com/ultravioletrs/cocos/agent/cvms"
|
|
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
|
"github.com/ultravioletrs/cocos/internal"
|
|
"golang.org/x/sync/errgroup"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/reflection"
|
|
)
|
|
|
|
// Compile-time assertion: building this file fails if *svc no longer
// implements cvmsgrpc.Service.
var _ cvmsgrpc.Service = new(svc)
|
|
|
|
// Fixed identity and networking defaults for this CVMS test server.
const (
	// defaultPort seeds the gRPC server config in main before environment
	// variables are applied (env.ParseWithOptions may override it).
	defaultPort = "7001"
	// svcName labels this process in log messages and is passed to the
	// gRPC server and stop-signal handler.
	svcName = "cvms_test_server"
)
|
|
|
|
// Command-line configuration. main fills these from flags (and derives
// dataPaths and attestedTLS) before the server starts; svc.Run reads them for
// each run request it builds.
var (
	// Direct upload mode: local files hashed by this server.
	algoPath       string
	dataPathString string   // raw --data-paths value
	dataPaths      []string // dataPathString split on commas

	// Transport security and the user key.
	attestedTLSString string // raw --attested-tls-bool value
	attestedTLS       bool   // attestedTLSString parsed as a boolean
	pubKeyFile        string
	clientCAFile      string
)

// Algorithm execution settings, used in both upload modes.
var (
	algoType       string
	algoArgsString string
)

// Remote resource configuration.
var (
	algoKBSURL          string
	algoSourceURL       string
	algoSourceType      string
	algoKBSResourcePath string
	algoHash            string

	datasetKBSURLs    string
	datasetSourceURLs string
	datasetSourceType string
	datasetKBSPaths   string
	datasetTypeString string
	datasetHash       string
	datasetDecompress string
)
|
|
|
|
type svc struct {
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.SendFunc, authInfo credentials.AuthInfo) {
|
|
s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress))
|
|
|
|
pubKey, err := os.ReadFile(pubKeyFile)
|
|
if err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to read public key file: %s", err))
|
|
return
|
|
}
|
|
pubPem, _ := pem.Decode(pubKey)
|
|
|
|
// Build datasets
|
|
var datasets []*cvms.Dataset
|
|
|
|
// Check if using remote datasets
|
|
var datasetURLs []string
|
|
var datasetKBSPathsList []string
|
|
var datasetKBSURLsList []string
|
|
if datasetSourceURLs != "" {
|
|
datasetURLs = strings.Split(datasetSourceURLs, ",")
|
|
}
|
|
if datasetKBSPaths != "" {
|
|
datasetKBSPathsList = strings.Split(datasetKBSPaths, ",")
|
|
}
|
|
if datasetKBSURLs != "" {
|
|
datasetKBSURLsList = strings.Split(datasetKBSURLs, ",")
|
|
}
|
|
|
|
var datasetDecompressList []bool
|
|
if datasetDecompress != "" {
|
|
parts := strings.Split(datasetDecompress, ",")
|
|
for _, p := range parts {
|
|
val, _ := strconv.ParseBool(p)
|
|
datasetDecompressList = append(datasetDecompressList, val)
|
|
}
|
|
}
|
|
|
|
// Parse dataset hash if provided
|
|
var dataHashBytes []byte
|
|
if datasetHash != "" {
|
|
var err error
|
|
dataHashBytes, err = hex.DecodeString(datasetHash)
|
|
if err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to decode dataset hash: %s", err))
|
|
return
|
|
}
|
|
if len(dataHashBytes) != 32 {
|
|
s.logger.Error(fmt.Sprintf("dataset hash must be 32 bytes (SHA256), got %d", len(dataHashBytes)))
|
|
return
|
|
}
|
|
} else {
|
|
// Default to empty/zero hash
|
|
dataHashBytes = make([]byte, 32)
|
|
}
|
|
|
|
if len(datasetURLs) > 0 && len(datasetKBSPathsList) > 0 {
|
|
// Remote datasets mode
|
|
if len(datasetURLs) != len(datasetKBSPathsList) {
|
|
s.logger.Error("dataset source URLs and KBS paths must have the same count")
|
|
return
|
|
}
|
|
|
|
for i := 0; i < len(datasetURLs); i++ {
|
|
srcType := datasetSourceType
|
|
if srcType == "" {
|
|
srcType = "oci-image"
|
|
}
|
|
d := &cvms.Dataset{
|
|
Hash: dataHashBytes,
|
|
UserKey: pubPem.Bytes,
|
|
Filename: fmt.Sprintf("dataset_%d.csv", i),
|
|
Source: &cvms.Source{
|
|
Type: srcType,
|
|
Url: datasetURLs[i],
|
|
KbsResourcePath: datasetKBSPathsList[i],
|
|
Encrypted: datasetKBSPathsList[i] != "",
|
|
},
|
|
}
|
|
if len(datasetKBSURLsList) > i && datasetKBSURLsList[i] != "" {
|
|
d.Kbs = &cvms.KBSConfig{
|
|
Url: datasetKBSURLsList[i],
|
|
Enabled: true,
|
|
}
|
|
}
|
|
datasets = append(datasets, d)
|
|
if len(datasetDecompressList) > i {
|
|
datasets[len(datasets)-1].Decompress = datasetDecompressList[i]
|
|
}
|
|
}
|
|
} else {
|
|
// Direct upload mode - use local files
|
|
for _, dataPath := range dataPaths {
|
|
if _, err := os.Stat(dataPath); os.IsNotExist(err) {
|
|
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
|
|
return
|
|
}
|
|
dataHash, err := internal.ChecksumHex(dataPath)
|
|
if err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
|
|
return
|
|
}
|
|
s.logger.Info("local dataset checksum", "path", dataPath, "hash", dataHash)
|
|
|
|
hashBytes, _ := hex.DecodeString(dataHash)
|
|
datasets = append(datasets, &cvms.Dataset{Hash: hashBytes, UserKey: pubPem.Bytes})
|
|
}
|
|
}
|
|
|
|
// Build algorithm
|
|
var algorithm *cvms.Algorithm
|
|
if algoSourceURL != "" && algoKBSResourcePath != "" {
|
|
// Remote algorithm mode
|
|
var algoHashBytes []byte
|
|
if algoHash != "" {
|
|
var err error
|
|
algoHashBytes, err = hex.DecodeString(algoHash)
|
|
if err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to decode algo hash: %s", err))
|
|
return
|
|
}
|
|
if len(algoHashBytes) != 32 {
|
|
s.logger.Error(fmt.Sprintf("algo hash must be 32 bytes (SHA256), got %d", len(algoHashBytes)))
|
|
return
|
|
}
|
|
} else {
|
|
algoHashBytes = make([]byte, 32)
|
|
}
|
|
|
|
var algoArgs []string
|
|
if algoArgsString != "" {
|
|
algoArgs = strings.Split(algoArgsString, ",")
|
|
}
|
|
|
|
var algoSrcType string
|
|
if algoSourceType != "" {
|
|
algoSrcType = algoSourceType
|
|
} else {
|
|
algoSrcType = "oci-image"
|
|
}
|
|
|
|
algorithm = &cvms.Algorithm{
|
|
Hash: algoHashBytes,
|
|
UserKey: pubPem.Bytes,
|
|
AlgoType: algoType,
|
|
AlgoArgs: algoArgs,
|
|
Source: &cvms.Source{
|
|
Type: algoSrcType,
|
|
Url: algoSourceURL,
|
|
KbsResourcePath: algoKBSResourcePath,
|
|
Encrypted: algoKBSResourcePath != "",
|
|
},
|
|
}
|
|
if algoKBSURL != "" {
|
|
algorithm.Kbs = &cvms.KBSConfig{
|
|
Url: algoKBSURL,
|
|
Enabled: true,
|
|
}
|
|
}
|
|
} else {
|
|
// Direct upload mode - use local file
|
|
fileHash, err := internal.ChecksumHex(algoPath)
|
|
if err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
|
|
return
|
|
}
|
|
s.logger.Info("local algorithm checksum", "path", algoPath, "hash", fileHash)
|
|
|
|
var algoArgs []string
|
|
if algoArgsString != "" {
|
|
algoArgs = strings.Split(algoArgsString, ",")
|
|
}
|
|
|
|
hashBytes, _ := hex.DecodeString(fileHash)
|
|
algorithm = &cvms.Algorithm{
|
|
Hash: hashBytes,
|
|
UserKey: pubPem.Bytes,
|
|
AlgoType: algoType,
|
|
AlgoArgs: algoArgs,
|
|
}
|
|
}
|
|
|
|
s.logger.Debug("sending computation run request")
|
|
if err := sendMessage(&cvms.ServerStreamMessage{
|
|
Message: &cvms.ServerStreamMessage_RunReq{
|
|
RunReq: &cvms.ComputationRunReq{
|
|
Id: "1",
|
|
Name: "sample computation",
|
|
Description: "sample descrption",
|
|
Datasets: datasets,
|
|
Algorithm: algorithm,
|
|
ResultConsumers: []*cvms.ResultConsumer{{UserKey: pubPem.Bytes}},
|
|
AgentConfig: &cvms.AgentConfig{
|
|
Port: "7002",
|
|
AttestedTls: attestedTLS,
|
|
ClientCaFile: clientCAFile,
|
|
},
|
|
},
|
|
},
|
|
}); err != nil {
|
|
s.logger.Error(fmt.Sprintf("failed to send run request: %s", err))
|
|
return
|
|
}
|
|
s.logger.Info("computation run request sent successfully")
|
|
|
|
// Keep the connection alive
|
|
<-ctx.Done()
|
|
s.logger.Info("connection closed")
|
|
}
|
|
|
|
func main() {
|
|
flagSet := flag.NewFlagSet("tests/cvms/main.go", flag.ContinueOnError)
|
|
flagSet.StringVar(&algoPath, "algo-path", "", "Path to the algorithm (for direct upload mode)")
|
|
flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
|
|
flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
|
|
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas (for direct upload mode)")
|
|
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
|
|
// Remote resource flags
|
|
flagSet.StringVar(&algoKBSURL, "algo-kbs-url", "", "Algorithm-specific KBS endpoint URL")
|
|
flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (docker://..., s3://..., https://..., etc.)")
|
|
flagSet.StringVar(&algoSourceType, "algo-source-type", "", "Algorithm source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
|
|
flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')")
|
|
flagSet.StringVar(&datasetKBSURLs, "dataset-kbs-urls", "", "Dataset-specific KBS endpoint URLs, comma-separated")
|
|
flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated")
|
|
flagSet.StringVar(&datasetSourceType, "dataset-source-type", "", "Dataset source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
|
|
flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated")
|
|
flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type")
|
|
flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated")
|
|
flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)")
|
|
flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type (deprecated, use --dataset-source-type)")
|
|
flagSet.StringVar(&datasetHash, "dataset-hash", "", "Dataset SHA256 hash (hex string)")
|
|
flagSet.StringVar(&datasetDecompress, "dataset-decompress", "", "Dataset decompression bools, comma-separated (e.g. true,false)")
|
|
|
|
flagSetParseError := flagSet.Parse(os.Args[1:])
|
|
if flagSetParseError != nil {
|
|
log.Fatalf("Error parsing flags: %v", flagSetParseError)
|
|
}
|
|
|
|
parsingError := !flagSet.Parsed()
|
|
var parsingErrorString strings.Builder
|
|
|
|
parsingErrorString.WriteString("\n")
|
|
|
|
// Validate that either algo-path OR (algo-source-url AND algo-kbs-path) is provided
|
|
if algoPath == "" && (algoSourceURL == "" || algoKBSResourcePath == "") {
|
|
parsingErrorString.WriteString("Either algo-path OR (algo-source-url AND algo-kbs-path) is required\n")
|
|
parsingError = true
|
|
}
|
|
|
|
if pubKeyFile == "" {
|
|
parsingErrorString.WriteString("Public key path is required\n")
|
|
parsingError = true
|
|
}
|
|
|
|
attestedTLSBoolValue, err := strconv.ParseBool(attestedTLSString)
|
|
if err != nil {
|
|
parsingErrorString.WriteString("Attested TLS flag is required and it must be a boolean value\n")
|
|
parsingError = true
|
|
attestedTLS = false
|
|
} else {
|
|
attestedTLS = attestedTLSBoolValue
|
|
}
|
|
|
|
if dataPathString != "" {
|
|
dataPaths = strings.Split(dataPathString, ",")
|
|
}
|
|
|
|
if parsingError {
|
|
parsingErrorString.WriteString("Usage :\n")
|
|
flagSet.SetOutput(&parsingErrorString)
|
|
flagSet.PrintDefaults()
|
|
log.Fatal(parsingErrorString.String())
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
incomingChan := make(chan *cvms.ClientStreamMessage)
|
|
|
|
logger, err := mglog.New(os.Stdout, "debug")
|
|
if err != nil {
|
|
log.Fatal(err.Error())
|
|
}
|
|
|
|
go func() {
|
|
for incoming := range incomingChan {
|
|
fmt.Println(incoming.Message)
|
|
}
|
|
}()
|
|
|
|
registerAgentServiceServer := func(srv *grpc.Server) {
|
|
reflection.Register(srv)
|
|
cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
|
|
}
|
|
grpcServerConfig := smqserver.Config{
|
|
Port: defaultPort,
|
|
}
|
|
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil {
|
|
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
|
return
|
|
}
|
|
|
|
gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger)
|
|
|
|
g.Go(func() error {
|
|
return gs.Start()
|
|
})
|
|
|
|
g.Go(func() error {
|
|
return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs)
|
|
})
|
|
|
|
if err := g.Wait(); err != nil {
|
|
logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err))
|
|
}
|
|
}
|