Files
cocos/test/cvms/main.go
T
Sammy Kerata Oina 6169766666
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
NOISSUE - Fix agent startup issues (#605)
* Update attestationFromCert function to include ccPlatform parameter for enhanced attestation processing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: migrate dependencies from supermq to magistrala and update build configurations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: update project dependencies, repository source, and support TDX QuoteV5 attestation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-06-11 17:08:24 +02:00

379 lines
12 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package main
import (
"context"
"encoding/hex"
"encoding/pem"
"flag"
"fmt"
"log"
"log/slog"
"os"
"strconv"
"strings"
mglog "github.com/absmach/magistrala/logger"
smqserver "github.com/absmach/magistrala/pkg/server"
grpcserver "github.com/absmach/magistrala/pkg/server/grpc"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/agent/cvms"
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
"github.com/ultravioletrs/cocos/internal"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
)
// Compile-time check that *svc satisfies cvmsgrpc.Service; it produces no
// runtime value.
var _ cvmsgrpc.Service = (*svc)(nil)
// svcName identifies this test server in log messages and is passed to the
// gRPC server and stop-signal helpers in main.
const svcName = "cvms_test_server"

// defaultPort is the initial gRPC port placed in the server config in main,
// before environment-based configuration is parsed into it.
const defaultPort = "7001"

// Local inputs and connection settings, populated from command-line flags in
// main and read by svc.Run.
var (
	algoPath       string
	dataPathString string
	dataPaths      []string // dataPathString split on commas in main

	pubKeyFile   string
	clientCAFile string

	attestedTLSString string
	attestedTLS       bool // parsed from attestedTLSString in main
)

// Algorithm settings for remote resources (algoType and algoArgsString are also
// used in direct upload mode), populated from command-line flags in main.
var (
	algoKBSURL          string
	algoSourceURL       string
	algoSourceType      string
	algoKBSResourcePath string
	algoType            string
	algoArgsString      string
	algoHash            string
)

// Dataset settings for remote resources, populated from command-line flags in
// main. The multi-valued ones are comma-separated strings.
var (
	datasetKBSURLs    string
	datasetSourceURLs string
	datasetSourceType string
	datasetKBSPaths   string
	datasetTypeString string
	datasetHash       string
	datasetDecompress string
)
// svc is the test implementation of cvmsgrpc.Service. Its only state is the
// logger; the computation it sends is built from the package-level flag values
// populated in main.
type svc struct {
	logger *slog.Logger
}

// Run is invoked for an agent connecting from ipAddress. It builds a single
// ComputationRunReq from the command-line configuration, sends it through
// sendMessage, and then blocks until ctx is done so the connection stays open.
// Any failure while building or sending the request is logged and ends the
// call early. authInfo is not used.
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.SendFunc, authInfo credentials.AuthInfo) {
	s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress))

	userKey, err := readPublicKey(pubKeyFile)
	if err != nil {
		s.logger.Error(err.Error())
		return
	}

	datasets, err := buildDatasets(s.logger, userKey)
	if err != nil {
		s.logger.Error(err.Error())
		return
	}

	algorithm, err := buildAlgorithm(s.logger, userKey)
	if err != nil {
		s.logger.Error(err.Error())
		return
	}

	s.logger.Debug("sending computation run request")
	if err := sendMessage(&cvms.ServerStreamMessage{
		Message: &cvms.ServerStreamMessage_RunReq{
			RunReq: &cvms.ComputationRunReq{
				Id:              "1",
				Name:            "sample computation",
				Description:     "sample descrption",
				Datasets:        datasets,
				Algorithm:       algorithm,
				ResultConsumers: []*cvms.ResultConsumer{{UserKey: userKey}},
				AgentConfig: &cvms.AgentConfig{
					Port:         "7002",
					AttestedTls:  attestedTLS,
					ClientCaFile: clientCAFile,
				},
			},
		},
	}); err != nil {
		s.logger.Error(fmt.Sprintf("failed to send run request: %s", err))
		return
	}
	s.logger.Info("computation run request sent successfully")

	// Keep the connection alive until ctx is done.
	<-ctx.Done()
	s.logger.Info("connection closed")
}

// readPublicKey reads the PEM file at path and returns the bytes of its first
// PEM block. Those bytes are used as the user key for the datasets, the
// algorithm and the result consumer.
func readPublicKey(path string) ([]byte, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read public key file: %w", err)
	}
	// pem.Decode returns a nil block when no PEM data is found; reject that
	// instead of dereferencing a nil block.
	block, _ := pem.Decode(raw)
	if block == nil {
		return nil, fmt.Errorf("failed to decode public key file %s: no PEM block found", path)
	}
	return block.Bytes, nil
}

// buildDatasets assembles the datasets for the run request. When both
// --dataset-source-urls and --dataset-kbs-paths are set, one remote dataset is
// built per source URL. Otherwise every path in --data-paths is hashed locally.
// The dataset hash and decompress flags are validated in either mode.
func buildDatasets(logger *slog.Logger, userKey []byte) ([]*cvms.Dataset, error) {
	sourceURLs := splitList(datasetSourceURLs)
	kbsPaths := splitList(datasetKBSPaths)
	kbsURLs := splitList(datasetKBSURLs)

	decompress, err := parseBoolList(datasetDecompress)
	if err != nil {
		return nil, fmt.Errorf("failed to parse dataset decompress flags: %w", err)
	}

	hash, err := parseSHA256Hex("dataset", datasetHash)
	if err != nil {
		return nil, err
	}

	if len(sourceURLs) == 0 || len(kbsPaths) == 0 {
		// Half-configured remote mode would otherwise fall back silently.
		if len(sourceURLs) > 0 || len(kbsPaths) > 0 {
			logger.Warn("remote datasets require both dataset source URLs and KBS paths; falling back to local data paths")
		}
		return buildLocalDatasets(logger, userKey)
	}
	return buildRemoteDatasets(sourceURLs, kbsPaths, kbsURLs, decompress, hash, userKey)
}

// buildRemoteDatasets creates one dataset per source URL. kbsPaths must be
// index-aligned with sourceURLs. kbsURLs and decompress may be shorter; missing
// entries leave the per-dataset KBS config unset and decompression disabled.
func buildRemoteDatasets(sourceURLs, kbsPaths, kbsURLs []string, decompress []bool, hash, userKey []byte) ([]*cvms.Dataset, error) {
	if len(sourceURLs) != len(kbsPaths) {
		return nil, fmt.Errorf("dataset source URLs and KBS paths must have the same count")
	}

	srcType := datasetSourceType
	if srcType == "" {
		srcType = "oci-image"
	}

	datasets := make([]*cvms.Dataset, 0, len(sourceURLs))
	for i, srcURL := range sourceURLs {
		d := &cvms.Dataset{
			Hash:     hash,
			UserKey:  userKey,
			Filename: fmt.Sprintf("dataset_%d.csv", i),
			Source: &cvms.Source{
				Type:            srcType,
				Url:             srcURL,
				KbsResourcePath: kbsPaths[i],
				// An empty per-dataset KBS path marks that dataset as unencrypted.
				Encrypted: kbsPaths[i] != "",
			},
		}
		if i < len(kbsURLs) && kbsURLs[i] != "" {
			d.Kbs = &cvms.KBSConfig{
				Url:     kbsURLs[i],
				Enabled: true,
			}
		}
		if i < len(decompress) {
			d.Decompress = decompress[i]
		}
		datasets = append(datasets, d)
	}
	return datasets, nil
}

// buildLocalDatasets hashes every file listed in --data-paths (direct upload
// mode). A missing file or checksum failure aborts with an error.
func buildLocalDatasets(logger *slog.Logger, userKey []byte) ([]*cvms.Dataset, error) {
	datasets := make([]*cvms.Dataset, 0, len(dataPaths))
	for _, dataPath := range dataPaths {
		if _, err := os.Stat(dataPath); os.IsNotExist(err) {
			return nil, fmt.Errorf("data file does not exist: %s", dataPath)
		}
		sum, hash, err := localChecksum(dataPath)
		if err != nil {
			return nil, err
		}
		logger.Info("local dataset checksum", "path", dataPath, "hash", sum)
		datasets = append(datasets, &cvms.Dataset{Hash: hash, UserKey: userKey})
	}
	return datasets, nil
}

// buildAlgorithm returns the algorithm for the run request. With both
// --algo-source-url and --algo-kbs-path set it describes a remote algorithm.
// Otherwise it hashes the local file at --algo-path.
func buildAlgorithm(logger *slog.Logger, userKey []byte) (*cvms.Algorithm, error) {
	// Arguments are passed through verbatim (not trimmed), since surrounding
	// whitespace may be meaningful to the algorithm.
	var algoArgs []string
	if algoArgsString != "" {
		algoArgs = strings.Split(algoArgsString, ",")
	}

	if algoSourceURL == "" || algoKBSResourcePath == "" {
		// Direct upload mode - use local file.
		sum, hash, err := localChecksum(algoPath)
		if err != nil {
			return nil, err
		}
		logger.Info("local algorithm checksum", "path", algoPath, "hash", sum)
		return &cvms.Algorithm{
			Hash:     hash,
			UserKey:  userKey,
			AlgoType: algoType,
			AlgoArgs: algoArgs,
		}, nil
	}

	hash, err := parseSHA256Hex("algo", algoHash)
	if err != nil {
		return nil, err
	}

	srcType := algoSourceType
	if srcType == "" {
		srcType = "oci-image"
	}

	algorithm := &cvms.Algorithm{
		Hash:     hash,
		UserKey:  userKey,
		AlgoType: algoType,
		AlgoArgs: algoArgs,
		Source: &cvms.Source{
			Type:            srcType,
			Url:             algoSourceURL,
			KbsResourcePath: algoKBSResourcePath,
			Encrypted:       algoKBSResourcePath != "",
		},
	}
	if algoKBSURL != "" {
		algorithm.Kbs = &cvms.KBSConfig{
			Url:     algoKBSURL,
			Enabled: true,
		}
	}
	return algorithm, nil
}

// localChecksum returns the hex checksum produced by internal.ChecksumHex for
// path, together with its decoded bytes.
func localChecksum(path string) (string, []byte, error) {
	sum, err := internal.ChecksumHex(path)
	if err != nil {
		return "", nil, fmt.Errorf("failed to calculate checksum: %w", err)
	}
	raw, err := hex.DecodeString(sum)
	if err != nil {
		return "", nil, fmt.Errorf("failed to decode checksum of %s: %w", path, err)
	}
	return sum, raw, nil
}

// parseSHA256Hex decodes an optional hex-encoded 32-byte (SHA-256 sized)
// digest. An empty value yields 32 zero bytes. what names the value in error
// messages ("dataset", "algo").
func parseSHA256Hex(what, value string) ([]byte, error) {
	const digestLen = 32
	if value == "" {
		return make([]byte, digestLen), nil
	}
	b, err := hex.DecodeString(value)
	if err != nil {
		return nil, fmt.Errorf("failed to decode %s hash: %w", what, err)
	}
	if len(b) != digestLen {
		return nil, fmt.Errorf("%s hash must be 32 bytes (SHA256), got %d", what, len(b))
	}
	return b, nil
}

// splitList splits a comma-separated flag value into whitespace-trimmed
// entries. An empty value yields nil. Empty entries between commas are kept so
// that positional lists (such as per-dataset KBS paths) stay index-aligned.
func splitList(value string) []string {
	if value == "" {
		return nil
	}
	parts := strings.Split(value, ",")
	for i, p := range parts {
		parts[i] = strings.TrimSpace(p)
	}
	return parts
}

// parseBoolList parses a comma-separated list of booleans accepted by
// strconv.ParseBool. Empty entries are treated as false so positions can be
// skipped (e.g. "true,,false"). Any other invalid entry is reported as an
// error rather than silently becoming false.
func parseBoolList(value string) ([]bool, error) {
	parts := splitList(value)
	out := make([]bool, 0, len(parts))
	for i, p := range parts {
		if p == "" {
			out = append(out, false)
			continue
		}
		v, err := strconv.ParseBool(p)
		if err != nil {
			return nil, fmt.Errorf("invalid boolean %q at position %d: %w", p, i, err)
		}
		out = append(out, v)
	}
	return out, nil
}
// main parses the command-line flags into the package-level configuration
// read by svc.Run. It then serves the CVMS gRPC service (with reflection
// enabled) until the stop-signal handler or the server returns. Client stream
// messages received through the service are printed to stdout.
func main() {
	flagSet := flag.NewFlagSet("tests/cvms/main.go", flag.ContinueOnError)
	flagSet.StringVar(&algoPath, "algo-path", "", "Path to the algorithm (for direct upload mode)")
	flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
	flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
	flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas (for direct upload mode)")
	flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
	// Remote resource flags
	flagSet.StringVar(&algoKBSURL, "algo-kbs-url", "", "Algorithm-specific KBS endpoint URL")
	flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (docker://..., s3://..., https://..., etc.)")
	flagSet.StringVar(&algoSourceType, "algo-source-type", "", "Algorithm source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
	flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')")
	flagSet.StringVar(&datasetKBSURLs, "dataset-kbs-urls", "", "Dataset-specific KBS endpoint URLs, comma-separated")
	flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated")
	flagSet.StringVar(&datasetSourceType, "dataset-source-type", "", "Dataset source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
	flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated")
	flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type")
	flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated")
	flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)")
	flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type (deprecated, use --dataset-source-type)")
	flagSet.StringVar(&datasetHash, "dataset-hash", "", "Dataset SHA256 hash (hex string)")
	flagSet.StringVar(&datasetDecompress, "dataset-decompress", "", "Dataset decompression bools, comma-separated (e.g. true,false)")

	if err := flagSet.Parse(os.Args[1:]); err != nil {
		log.Fatalf("Error parsing flags: %v", err)
	}

	// Collect every validation problem so the user sees them all at once.
	parsingError := false
	var parsingErrorString strings.Builder
	parsingErrorString.WriteString("\n")

	// Validate that either algo-path OR (algo-source-url AND algo-kbs-path) is provided
	if algoPath == "" && (algoSourceURL == "" || algoKBSResourcePath == "") {
		parsingErrorString.WriteString("Either algo-path OR (algo-source-url AND algo-kbs-path) is required\n")
		parsingError = true
	}
	if pubKeyFile == "" {
		parsingErrorString.WriteString("Public key path is required\n")
		parsingError = true
	}
	attestedTLSBoolValue, err := strconv.ParseBool(attestedTLSString)
	if err != nil {
		parsingErrorString.WriteString("Attested TLS flag is required and it must be a boolean value\n")
		parsingError = true
	} else {
		attestedTLS = attestedTLSBoolValue
	}

	if dataPathString != "" {
		dataPaths = strings.Split(dataPathString, ",")
		for i, p := range dataPaths {
			dataPaths[i] = strings.TrimSpace(p)
		}
	}

	// --dataset-type is a deprecated alias for --dataset-source-type. Honor it
	// when the new flag is not set instead of silently ignoring it.
	if datasetSourceType == "" && datasetTypeString != "" {
		datasetSourceType = datasetTypeString
	}

	if parsingError {
		parsingErrorString.WriteString("Usage :\n")
		flagSet.SetOutput(&parsingErrorString)
		flagSet.PrintDefaults()
		log.Fatal(parsingErrorString.String())
	}

	logger, err := mglog.New(os.Stdout, "debug")
	if err != nil {
		log.Fatal(err.Error())
	}

	// Load the server configuration before registering any defers so a
	// failure can exit with a non-zero status.
	grpcServerConfig := smqserver.Config{
		Port: defaultPort,
	}
	if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil {
		logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
		os.Exit(1)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	g, ctx := errgroup.WithContext(ctx)

	incomingChan := make(chan *cvms.ClientStreamMessage)
	go func() {
		for incoming := range incomingChan {
			fmt.Println(incoming.Message)
		}
	}()

	registerAgentServiceServer := func(srv *grpc.Server) {
		reflection.Register(srv)
		cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
	}

	gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger)
	g.Go(func() error {
		return gs.Start()
	})
	g.Go(func() error {
		return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs)
	})

	if err := g.Wait(); err != nil {
		logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err))
	}
}