Files
Sammy Kerata Oina 6169766666
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
NOISSUE - Fix agent startup issues (#605)
* Update attestationFromCert function to include ccPlatform parameter for enhanced attestation processing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: migrate dependencies from supermq to magistrala and update build configurations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: update project dependencies, repository source, and support TDX QuoteV5 attestation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-06-11 17:08:24 +02:00

1253 lines
40 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
sync "sync"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events"
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
"github.com/ultravioletrs/cocos/pkg/oci"
"github.com/ultravioletrs/cocos/pkg/resource"
"golang.org/x/crypto/sha3"
)
// Compile-time assertion that *agentService satisfies the Service interface.
var _ Service = (*agentService)(nil)
// AgentState enumerates the states of the agent's computation state machine.
// The machine is constructed in Idle and moves between states in response to
// AgentEvent values.
//
//go:generate stringer -type=AgentState
type AgentState int
const (
// Idle is the state the machine is created in and reset to.
Idle AgentState = iota
// ReceivingManifest is the state in which InitComputation is accepted.
ReceivingManifest
// ReceivingAlgorithm is the state in which the algorithm is downloaded or uploaded via Algo.
ReceivingAlgorithm
// ReceivingData is the state in which datasets are downloaded or uploaded via Data.
ReceivingData
// Running is the state whose entry action runs the computation.
Running
// ConsumingResults is entered from Running on RunComplete.
ConsumingResults
// Complete is entered from ConsumingResults on ResultsConsumed.
Complete
// Failed is entered on RunFailed.
Failed
)
// AgentEvent enumerates the events that drive AgentState transitions.
//
//go:generate stringer -type=AgentEvent
type AgentEvent int
const (
// Start moves the machine from Idle to ReceivingManifest.
Start AgentEvent = iota
// ManifestReceived is sent when InitComputation returns.
ManifestReceived
// AlgorithmReceived is sent once the algorithm has been stored.
AlgorithmReceived
// DataReceived is sent once no datasets remain pending.
DataReceived
// RunComplete moves the machine from Running to ConsumingResults.
RunComplete
// ResultsConsumed moves the machine from ConsumingResults to Complete.
ResultsConsumed
// RunFailed moves the machine to Failed.
RunFailed
)
// Status enumerates status labels; their String() forms are passed to
// publishEvent by the state machine actions in this file.
//
//go:generate stringer -type=Status
type Status uint8
const (
IdleState Status = iota
// InProgress is published when entering ReceivingAlgorithm and ReceivingData.
InProgress
// Ready is published when entering ConsumingResults.
Ready
// Completed is published when entering Complete.
Completed
Terminated
Warning
Starting
)
const (
// algoFilePermission is the mode applied with os.Chmod to the stored
// algorithm file (read/write/execute for the owner only).
algoFilePermission = 0o700
)
// NOTE(review): these are variables rather than constants, presumably so they
// can be overridden (e.g. in tests) — confirm against callers.
var (
// ImaMeasurementsFilePath is the path of the IMA ASCII runtime measurements file.
ImaMeasurementsFilePath = "/sys/kernel/security/integrity/ima/ascii_runtime_measurements"
// ImaPcrIndex is the PCR index associated with IMA measurements.
ImaPcrIndex = 10
)
// ensureDir guarantees that path exists and is a directory.
//
// An existing directory is left untouched. An existing non-directory entry is
// removed and replaced by a directory. A missing path is created, including
// any missing parents, with the given mode.
func ensureDir(path string, mode os.FileMode) error {
	fi, statErr := os.Stat(path)
	if statErr != nil && !os.IsNotExist(statErr) {
		return fmt.Errorf("stating path %q: %w", path, statErr)
	}

	if statErr == nil {
		if fi.IsDir() {
			return nil
		}
		// Something other than a directory occupies the path; clear it first.
		if rmErr := os.Remove(path); rmErr != nil {
			return fmt.Errorf("removing non-directory path %q: %w", path, rmErr)
		}
	}

	if mkErr := os.MkdirAll(path, mode); mkErr != nil {
		return fmt.Errorf("creating directory %q: %w", path, mkErr)
	}
	return nil
}
// Sentinel errors returned by the agent service.
var (
// ErrMalformedEntity indicates a malformed entity specification (e.g.
// invalid username or password).
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrUnauthorizedAccess indicates missing or invalid credentials provided
// when accessing a protected resource.
ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
// ErrUndeclaredAlgorithm indicates the algorithm was not declared in the computation manifest.
ErrUndeclaredAlgorithm = errors.New("algorithm not declared in computation manifest")
// ErrUndeclaredDataset indicates the dataset was not declared in the computation manifest.
ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")
// ErrAllManifestItemsReceived indicates no new computation manifest items are expected.
ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")
// ErrUndeclaredConsumer indicates the consumer requesting results is not declared in the computation manifest.
ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")
// ErrResultsNotReady indicates the computation results are not ready.
ErrResultsNotReady = errors.New("computation results are not yet ready")
// ErrStateNotReady indicates the agent received a request in the wrong state.
ErrStateNotReady = errors.New("agent not expecting this operation in the current state")
// ErrHashMismatch indicates the provided algorithm/dataset does not match the hash in the manifest.
ErrHashMismatch = errors.New("malformed data, hash does not match manifest")
// ErrFileNameMismatch indicates the provided dataset filename does not match the filename in the manifest.
ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")
// ErrAllResultsConsumed indicates all results have been consumed.
ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")
// ErrAttestationFailed indicates that attestation failed.
ErrAttestationFailed = errors.New("failed to get raw quote")
// ErrAttestationVTpmFailed indicates that vTPM attestation failed.
ErrAttestationVTpmFailed = errors.New("failed to get vTPM quote")
// ErrFetchAzureToken indicates that fetching the Azure token failed.
ErrFetchAzureToken = errors.New("failed to get azure token")
// ErrAttestationType indicates that the requested attestation type does not exist or is not supported.
ErrAttestationType = errors.New("attestation type does not exist or is not supported")
)
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
// InitComputation stores the computation manifest. agentService accepts it
// only while in the ReceivingManifest state.
InitComputation(ctx context.Context, cmp Computation) error
// StopComputation stops the current computation and resets the agent.
StopComputation(ctx context.Context) error
// Algo supplies the algorithm. agentService accepts it only while in the
// ReceivingAlgorithm state and verifies it against the manifest hash.
Algo(ctx context.Context, algorithm Algorithm) error
// Data supplies a dataset. agentService accepts it only while in the
// ReceivingData state.
Data(ctx context.Context, dataset Dataset) error
// Result returns the computation result bytes.
Result(ctx context.Context) ([]byte, error)
// Attestation returns an attestation report of the requested platform type
// for the given report data and nonce.
Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
// IMAMeasurements returns two byte slices.
// NOTE(review): presumably the IMA measurement log and the matching PCR
// value — confirm against the implementation.
IMAMeasurements(ctx context.Context) ([]byte, []byte, error)
// AzureAttestationToken returns an Azure attestation token for the given nonce.
AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error)
// State returns the string form of the current agent state.
State() string
}
// OCIClient is the subset of OCI image operations the agent service depends on.
// New assigns the Skopeo-based client from the oci package to it.
type OCIClient interface {
// PullAndDecrypt fetches the image described by source into destDir.
PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
// ToDockerArchive converts the OCI layout in ociDir into a Docker archive at destFile.
ToDockerArchive(ctx context.Context, ociDir, destFile string) error
}
// agentService is the Service implementation. It tracks a single computation
// at a time and drives it through a state machine.
type agentService struct {
// mu is held by the service methods and state actions while they read or
// mutate the computation-related fields.
mu sync.Mutex
computation Computation // Holds the current computation request details.
// runnerClient is the client for the computation runner (used to stop runs).
runnerClient runner_client.Client
// algoType is the algorithm type; it defaults to algorithm.AlgoTypeBin when unset.
algoType string
// algoArgs are the arguments recorded for the algorithm.
algoArgs []string
// algoRequirements holds the requirements file contents, if any.
algoRequirements []byte
// algoReceived reports whether the algorithm has already been stored.
algoReceived bool
result []byte // Stores the result of the computation.
sm statemachine.StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
eventSvc events.Service // Service for publishing events related to computation.
attestationClient attestation_client.Client // Client for attestation service.
logger *slog.Logger // Logger for the agent service.
resultsConsumed bool // Indicates if the results have been consumed.
cancel context.CancelFunc // Cancels the computation context.
vmpl int // VMPL at which the Agent is running.
// ociClient pulls and converts OCI images; it may be unset if the Skopeo
// client could not be created.
ociClient OCIClient
resourceRegistry *resource.Registry // Registry of resource downloaders (S3, HTTP, etc.)
}
// Compile-time assertion that *agentService satisfies the Service interface.
// NOTE(review): an identical assertion already appears near the top of this
// file; one of the two can be dropped.
var _ Service = (*agentService)(nil)
// New instantiates the agent service implementation.
//
// It builds the state machine (starting in Idle), registers the resource
// downloaders, installs the static transitions and per-state actions, starts
// the state machine in a background goroutine bound to a cancellable child of
// ctx, and finally sends Start so the agent begins waiting for a manifest.
//
// A failure to create the Skopeo client is logged and tolerated: the service is
// still returned, but OCI sources are then unavailable.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmlp int) Service {
	sm := statemachine.NewStateMachine(Idle)
	ctx, cancel := context.WithCancel(ctx)
	svc := &agentService{
		sm:                sm,
		eventSvc:          eventSvc,
		attestationClient: attestationClient,
		runnerClient:      runnerClient,
		logger:            logger,
		cancel:            cancel,
		vmpl:              vmlp,
	}

	workDir := filepath.Join(os.TempDir(), "cocos-oci")
	skopeoClient, err := oci.NewSkopeoClient(workDir)
	if err != nil {
		logger.Warn("failed to create Skopeo client", "error", err)
	}

	// Initialize resource downloader registry with all supported source types.
	reg := resource.NewRegistry()
	if skopeoClient != nil {
		// Only store a usable client. Assigning a nil pointer to the interface
		// field would make it compare non-nil and defeat the
		// "OCI client not initialized" guard used before pulling images.
		svc.ociClient = skopeoClient
		reg.Register(resource.NewOCIDownloader(skopeoClient))
	}
	reg.Register(resource.NewHTTPSDownloader())
	reg.Register(resource.NewHTTPDownloader())
	reg.Register(resource.NewS3Downloader(""))
	reg.Register(resource.NewGCSDownloader())
	svc.resourceRegistry = reg

	// Static transitions. The AlgorithmReceived/DataReceived transitions depend
	// on the manifest and are added later by InitComputation.
	transitions := []statemachine.Transition{
		{From: Idle, Event: Start, To: ReceivingManifest},
		{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
		{From: ReceivingAlgorithm, Event: RunFailed, To: Failed},
		{From: ReceivingData, Event: RunFailed, To: Failed},
		{From: Running, Event: RunComplete, To: ConsumingResults},
		{From: Running, Event: RunFailed, To: Failed},
		{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
	}
	for _, t := range transitions {
		sm.AddTransition(t)
	}

	sm.SetAction(ReceivingAlgorithm, svc.downloadAlgorithmIfRemote)
	sm.SetAction(ReceivingData, svc.downloadDatasetsIfRemote)
	sm.SetAction(Running, svc.runComputation)
	sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
	sm.SetAction(Complete, svc.publishEvent(Completed.String()))
	sm.SetAction(Failed, svc.publishEvent(Failed.String()))

	go func() {
		if err := sm.Start(ctx); err != nil {
			logger.Error(err.Error())
		}
	}()

	// NOTE(review): the sleeps presumably give the state machine goroutine time
	// to start and to process Start before the service is handed out — confirm.
	time.Sleep(100 * time.Millisecond)
	sm.SendEvent(Start)
	time.Sleep(100 * time.Millisecond)

	return svc
}
// State reports the current state of the agent's state machine as a string.
func (as *agentService) State() string {
	current := as.sm.GetState()
	return current.String()
}
// InitComputation records the computation manifest and prepares the state
// machine for it.
//
// It is accepted only in the ReceivingManifest state; otherwise
// ErrStateNotReady is returned. The manifest is stored, its algorithm and
// dataset source configuration is logged, and the manifest-dependent
// transitions are registered: with no datasets the algorithm leads straight to
// Running, otherwise via ReceivingData. ManifestReceived is sent when the
// method returns, after the mutex has been released.
func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error {
	if as.sm.GetState() != ReceivingManifest {
		return ErrStateNotReady
	}
	// Registered before the lock so it fires after the deferred Unlock.
	defer as.sm.SendEvent(ManifestReceived)

	as.mu.Lock()
	defer as.mu.Unlock()

	as.computation = cmp

	if algo := cmp.Algorithm; algo == nil {
		as.logger.Info("received computation manifest (no algorithm)",
			"computation_id", cmp.ID,
			"dataset_count", len(cmp.Datasets))
	} else {
		algoKBSURL := ""
		if algo.KBS != nil {
			algoKBSURL = algo.KBS.URL
		}

		as.logger.Info("received computation manifest",
			"computation_id", cmp.ID,
			"algo_has_source", algo.Source != nil,
			"algo_kbs_enabled", algo.KBS != nil && algo.KBS.Enabled,
			"algo_kbs_url", algoKBSURL,
			"dataset_count", len(cmp.Datasets))

		if algo.Source == nil {
			as.logger.Info("algorithm remote source NOT configured - will wait for direct upload")
		} else {
			as.logger.Info("algorithm remote source configured",
				"url", algo.Source.URL,
				"kbs_resource_path", algo.Source.KBSResourcePath,
				"kbs_enabled", algo.KBS != nil && algo.KBS.Enabled,
				"kbs_url", algoKBSURL)
		}
	}

	as.logger.Info("Global KBS is NOT USED (per-resource configuration only)")

	for idx, ds := range cmp.Datasets {
		if ds.Source == nil {
			continue
		}
		dsKBSURL := ""
		if ds.KBS != nil {
			dsKBSURL = ds.KBS.URL
		}
		as.logger.Info("dataset remote source configured",
			"index", idx,
			"filename", ds.Filename,
			"url", ds.Source.URL,
			"kbs_resource_path", ds.Source.KBSResourcePath,
			"kbs_enabled", ds.KBS != nil && ds.KBS.Enabled,
			"kbs_url", dsKBSURL)
	}

	// Manifest-dependent transitions.
	if len(cmp.Datasets) == 0 {
		as.sm.AddTransition(statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
		return nil
	}
	as.sm.AddTransition(statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
	as.sm.AddTransition(statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})

	return nil
}
// StopComputation aborts the current computation and returns the agent to a
// state where a new manifest can be received.
//
// It publishes a "Stopped" event, cancels the state machine context, asks the
// runner to stop (a failure there is only logged), removes the datasets and
// results directories and the algorithm file, clears all per-computation
// fields, resets the state machine to Idle and restarts it.
//
// An error is returned if the datasets or results directory cannot be removed.
// In that case the method returns before the state machine is reset and
// restarted.
func (as *agentService) StopComputation(ctx context.Context) error {
	as.mu.Lock()
	defer as.mu.Unlock()

	as.eventSvc.SendEvent(as.computation.ID, "Stopped", "Stopped", json.RawMessage{})

	as.cancel()

	if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil {
		as.logger.Warn("failed to stop runner", "error", err)
		// proceed to cleanup
	}

	// Wrap with %w so callers can inspect the underlying filesystem error.
	if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
		return fmt.Errorf("error removing datasets directory: %w", err)
	}

	if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
		return fmt.Errorf("error removing results directory: %w", err)
	}

	if err := os.Remove("algo"); err != nil && !os.IsNotExist(err) {
		as.logger.Warn("error removing algorithm file", "error", err)
	}

	as.sm.Reset(Idle)

	as.computation = Computation{}
	as.algoReceived = false
	as.algoType = ""
	as.algoArgs = nil
	as.algoRequirements = nil
	as.result = nil
	as.runError = nil
	as.resultsConsumed = false

	// NOTE(review): the restarted state machine is bound to a child of the
	// caller's ctx. If callers pass a request-scoped context, the machine stops
	// when that request ends — confirm the lifetime of ctx at the call sites.
	ctx, cancel := context.WithCancel(ctx)
	as.cancel = cancel

	go func() {
		if err := as.sm.Start(ctx); err != nil {
			as.logger.Error(err.Error())
		}
	}()

	time.Sleep(100 * time.Millisecond)
	as.sm.SendEvent(Start)
	time.Sleep(100 * time.Millisecond)

	return nil
}
// downloadAlgorithmIfRemote is the entry action of the ReceivingAlgorithm state.
//
// It publishes InProgress and then, if the manifest declares an algorithm with
// a remote source and KBS enabled, downloads it, verifies its SHA3-256 hash
// against the manifest, stores it as "algo" in the working directory and sends
// AlgorithmReceived. Any failure is recorded in runError, logged, and reported
// with RunFailed. When no remote download applies the action returns and the
// agent keeps waiting for a direct upload through Algo.
func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) {
	as.publishEvent(InProgress.String())(state)

	as.mu.Lock()
	defer as.mu.Unlock()

	algo := as.computation.Algorithm
	if algo == nil {
		as.logger.Info("algorithm automatic download not triggered, (no algorithm in manifest)")
		return
	}

	kbsEnabled := algo.KBS != nil && algo.KBS.Enabled
	kbsURL := ""
	if algo.KBS != nil {
		kbsURL = algo.KBS.URL
	}

	as.logger.Info("checking if algorithm should be downloaded automatically",
		"algo_has_source", algo.Source != nil,
		"kbs_enabled", kbsEnabled)

	// Nothing to fetch: the algorithm arrives through the Algo() RPC instead.
	if algo.Source == nil || !kbsEnabled {
		as.logger.Info("algorithm automatic download not triggered, waiting for direct upload",
			"reason", "no remote source or KBS not enabled")
		return
	}

	// recordFailure stores and logs the error; callers emit RunFailed themselves
	// so any cleanup keeps its place between logging and the event.
	recordFailure := func(failure error) {
		as.runError = failure
		as.logger.Error(as.runError.Error())
	}

	as.logger.Info("downloading algorithm from remote source",
		"url", algo.Source.URL,
		"kbs_resource_path", algo.Source.KBSResourcePath,
		"kbs_url", kbsURL)

	// The download is not tied to any request, so a background context is used.
	fetched, err := as.downloadAndDecryptResource(context.Background(), algo.Source, kbsURL, "algorithm")
	if err != nil {
		recordFailure(fmt.Errorf("failed to download and decrypt algorithm: %w", err))
		as.sm.SendEvent(RunFailed)
		return
	}

	// Integrity check against the manifest.
	if digest := sha3.Sum256(fetched.Data); digest != algo.Hash {
		recordFailure(fmt.Errorf("algorithm hash mismatch: expected %x, got %x", algo.Hash, digest))
		as.sm.SendEvent(RunFailed)
		return
	}

	workDir, err := os.Getwd()
	if err != nil {
		recordFailure(fmt.Errorf("error getting current directory: %w", err))
		as.sm.SendEvent(RunFailed)
		return
	}

	// An extraction directory (e.g. from an OCI image) is copied wholesale so
	// files that accompany the algorithm end up next to it.
	if fetched.SourceDir != "" {
		as.logger.Info("copying extracted algorithm directory", "src", fetched.SourceDir, "dst", workDir)
		// "<dir>/." makes cp copy the directory's contents rather than the directory itself.
		copyCmd := exec.Command("cp", "-r", fetched.SourceDir+"/.", workDir)
		if output, err := copyCmd.CombinedOutput(); err != nil {
			recordFailure(fmt.Errorf("error copying algorithm directory: %v, output: %s", err, output))
			as.sm.SendEvent(RunFailed)
			return
		}
	}

	algoFile, err := os.Create(filepath.Join(workDir, "algo"))
	if err != nil {
		recordFailure(fmt.Errorf("error creating algorithm file: %w", err))
		as.sm.SendEvent(RunFailed)
		return
	}

	if _, err := algoFile.Write(fetched.Data); err != nil {
		recordFailure(fmt.Errorf("error writing algorithm to file: %w", err))
		algoFile.Close()
		as.sm.SendEvent(RunFailed)
		return
	}

	if err := os.Chmod(algoFile.Name(), algoFilePermission); err != nil {
		recordFailure(fmt.Errorf("error changing file permissions: %w", err))
		algoFile.Close()
		as.sm.SendEvent(RunFailed)
		return
	}

	if err := algoFile.Close(); err != nil {
		recordFailure(fmt.Errorf("error closing file: %w", err))
		as.sm.SendEvent(RunFailed)
		return
	}

	as.algoReceived = true
	as.algoRequirements = fetched.Requirements // Kept for later installation.

	// The initramfs may have already provisioned /cocos/datasets.
	if err := ensureDir(algorithm.DatasetsDir, 0o755); err != nil {
		recordFailure(fmt.Errorf("error creating datasets directory: %w", err))
		as.sm.SendEvent(RunFailed)
		return
	}

	as.algoType = algo.AlgoType
	if as.algoType == "" {
		as.algoType = string(algorithm.AlgoTypeBin)
	}
	as.algoArgs = algo.AlgoArgs

	as.logger.Info("algorithm downloaded and saved successfully", "type", as.algoType, "has_requirements", len(fetched.Requirements) > 0)
	as.sm.SendEvent(AlgorithmReceived)
}
// downloadDatasetsIfRemote is the entry action of the ReceivingData state.
//
// It publishes InProgress and then downloads every pending dataset that has a
// remote source with KBS enabled. Each download is verified against the
// manifest's SHA3-256 hash and then either unzipped into the datasets
// directory (Decompress) or written there under the dataset's filename.
// Successfully stored datasets are removed from the pending list; once the
// list is empty DataReceived is sent. Datasets without a remote source stay
// pending and are expected through Data.
//
// Every failure is recorded in runError, logged, and reported with RunFailed.
func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
	as.publishEvent(InProgress.String())(state)

	as.mu.Lock()
	defer as.mu.Unlock()

	// fail records the error consistently for every failure path.
	fail := func(err error) {
		as.runError = err
		as.logger.Error(err.Error())
		as.sm.SendEvent(RunFailed)
	}

	// Check if any datasets should be downloaded from remote sources.
	hasRemoteDatasets := false
	for _, d := range as.computation.Datasets {
		if d.Source != nil && d.KBS != nil && d.KBS.Enabled {
			hasRemoteDatasets = true
			break
		}
	}
	if !hasRemoteDatasets {
		// No remote datasets, wait for direct uploads via Data() RPC calls.
		return
	}

	ctx := context.Background()

	// Iterate backwards so removing an element does not shift unvisited ones.
	for i := len(as.computation.Datasets) - 1; i >= 0; i-- {
		d := as.computation.Datasets[i]
		if d.Source == nil || d.KBS == nil || !d.KBS.Enabled {
			continue
		}
		kbsURL := d.KBS.URL

		as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL)

		res, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset")
		if err != nil {
			fail(fmt.Errorf("failed to download and decrypt dataset %s: %w", d.Filename, err))
			return
		}

		// Verify hash.
		hash := sha3.Sum256(res.Data)
		if hash != d.Hash {
			fail(fmt.Errorf("dataset %s hash mismatch: expected %x, got %x", d.Filename, d.Hash, hash))
			return
		}

		if d.Decompress {
			// Archives are expanded in place; no file named after the archive is
			// created, so no empty placeholder or open handle is left behind.
			if err := internal.UnzipFromMemory(res.Data, algorithm.DatasetsDir); err != nil {
				fail(fmt.Errorf("failed to unzip dataset %s: %w", d.Filename, err))
				return
			}
		} else {
			// os.WriteFile creates/truncates with the same 0666 (pre-umask) mode
			// as os.Create and always closes the handle.
			dest := filepath.Join(algorithm.DatasetsDir, d.Filename)
			if err := os.WriteFile(dest, res.Data, 0o666); err != nil {
				fail(fmt.Errorf("error writing dataset %s to file: %w", d.Filename, err))
				return
			}
		}

		// Remove from pending datasets.
		as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
		as.logger.Info("dataset downloaded and saved successfully", "filename", d.Filename)
	}

	// If all datasets are downloaded, send DataReceived event.
	if len(as.computation.Datasets) == 0 {
		as.logger.Info("all datasets downloaded successfully")
		as.sm.SendEvent(DataReceived)
	}
	// Otherwise, wait for remaining datasets to be uploaded via Data() RPC calls.
}
// DecryptedResource holds the data and metadata of a downloaded and decrypted resource.
type DecryptedResource struct {
// Data is the resource content; callers hash it with SHA3-256 and compare
// the digest with the manifest.
Data []byte
// Requirements is the content of the requirements file found alongside an
// algorithm extracted from an OCI image; empty otherwise.
Requirements []byte
// SourceDir is the directory an OCI image was extracted into. When set for
// an algorithm, its contents are copied into the working directory.
SourceDir string
}
// downloadAndDecryptResource fetches a resource and returns its decrypted form,
// dispatching on the source type.
//
// The type is taken from source.Type or, when that is empty, inferred from the
// URL. OCI images go through the OCI client; S3, GCS, HTTP and HTTPS sources go
// through the generic downloader path. An unrecognized URL or type yields an error.
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) {
	kind := source.Type
	if kind == "" {
		// Fall back to the URL scheme when the manifest gives no explicit type.
		if kind = inferSourceType(source.URL); kind == "" {
			return nil, fmt.Errorf("unsupported source URL format: %s (specify type explicitly or use a recognized URL scheme)", source.URL)
		}
	}

	if kind == resource.SourceTypeOCIImage {
		return as.downloadAndDecryptOCIImage(ctx, source, kbsURL, resourceType)
	}

	switch kind {
	case resource.SourceTypeS3, resource.SourceTypeGCS, resource.SourceTypeHTTPS, resource.SourceTypeHTTP:
		return as.downloadAndDecryptGenericResource(ctx, source, kind, kbsURL, resourceType)
	}

	return nil, fmt.Errorf("unsupported source type: %s", kind)
}
// inferSourceType derives the resource source type from a URL's scheme.
//
// It returns the empty string when u is empty, cannot be parsed, or does not
// look like any supported source. Bare image references are treated as OCI
// images: either a scheme-less string containing a slash
// ("docker.io/library/ubuntu:latest"), or a "name:tag" form that the URL
// parser reads as scheme plus opaque part ("ubuntu:latest").
func inferSourceType(u string) string {
	if u == "" {
		return ""
	}

	parsed, err := url.Parse(u)
	if err != nil {
		return ""
	}

	scheme := parsed.Scheme
	switch {
	case scheme == "docker" || scheme == "oci":
		return resource.SourceTypeOCIImage
	case scheme == "s3":
		return resource.SourceTypeS3
	case scheme == "gs":
		return resource.SourceTypeGCS
	case scheme == "https":
		return resource.SourceTypeHTTPS
	case scheme == "http":
		return resource.SourceTypeHTTP
	}

	if scheme == "" {
		// No scheme at all: accept it as an image reference only if it has a
		// path separator.
		if strings.Contains(u, "/") {
			return resource.SourceTypeOCIImage
		}
		return ""
	}

	// Unknown scheme. A colon without "://" plus a non-empty opaque part is the
	// shape of a bare "image:tag" reference.
	if parsed.Opaque != "" && !strings.Contains(u, "://") {
		return resource.SourceTypeOCIImage
	}
	return ""
}
// downloadAndDecryptGenericResource downloads a resource using the downloader
// registered for sourceType and, when the source is marked encrypted and has a
// KBS resource path, decrypts it with a key fetched from the KBS.
//
// The resource is staged in a temporary file under
// os.TempDir()/cocos-resources/<resourceType>/, which is removed before
// returning since only the in-memory bytes are handed back.
//
// It returns an error if the registry is missing, no downloader exists for the
// type, or the download, read, key retrieval or decryption fails.
func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, source *ResourceSource, sourceType, kbsURL, resourceType string) (*DecryptedResource, error) {
	as.logger.Info(fmt.Sprintf("downloading %s resource (type=%s url=%s encrypted=%t kbs_path=%s)",
		resourceType, sourceType, source.URL, source.Encrypted, source.KBSResourcePath))

	if as.resourceRegistry == nil {
		return nil, fmt.Errorf("resource registry not initialized")
	}

	downloader, err := as.resourceRegistry.Get(sourceType)
	if err != nil {
		return nil, fmt.Errorf("no downloader for source type %s: %w", sourceType, err)
	}

	// Download to temporary file.
	destPath := filepath.Join(os.TempDir(), "cocos-resources", resourceType, filepath.Base(source.URL))
	if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}

	// The staged copy is not needed once its bytes are in memory; remove it on
	// every exit path so neither partial downloads nor resource contents linger
	// in the temp directory.
	defer func() {
		if err := os.Remove(destPath); err != nil && !os.IsNotExist(err) {
			as.logger.Warn("failed to remove temporary resource file", "path", destPath, "error", err)
		}
	}()

	if err := downloader.Download(ctx, source.URL, destPath); err != nil {
		return nil, fmt.Errorf("failed to download resource from %s: %w", source.URL, err)
	}

	as.logger.Info("resource downloaded", "dest", destPath)

	// Read the downloaded file.
	data, err := os.ReadFile(destPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read downloaded resource: %w", err)
	}

	// If encrypted, retrieve key from KBS and decrypt.
	// NOTE(review): a source marked Encrypted without a KBSResourcePath is
	// returned undecrypted; the caller's hash check is what rejects it.
	if source.Encrypted && source.KBSResourcePath != "" {
		as.logger.Info("resource is encrypted, retrieving decryption key from KBS",
			"kbs_path", source.KBSResourcePath,
			"kbs_url", kbsURL)

		key, err := as.getKeyFromKBS(ctx, kbsURL, source.KBSResourcePath)
		if err != nil {
			return nil, fmt.Errorf("failed to retrieve decryption key from KBS: %w", err)
		}

		plaintext, err := resource.DecryptData(data, key)
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt resource: %w", err)
		}
		data = plaintext

		as.logger.Info("resource decrypted successfully", "plaintext_size", len(data))
	}

	return &DecryptedResource{
		Data: data,
	}, nil
}
// getKeyFromKBS retrieves a decryption key from the Key Broker Service.
//
// The key is fetched with a plain HTTP GET from
// <kbsURL>/kbs/v0/resource/<resourcePath>. A trailing slash on kbsURL and a
// leading slash on resourcePath are tolerated so the joined URL never contains
// an accidental "//". An empty kbsURL is rejected.
func (as *agentService) getKeyFromKBS(ctx context.Context, kbsURL, resourcePath string) ([]byte, error) {
	if kbsURL == "" {
		return nil, fmt.Errorf("KBS not configured or not enabled")
	}

	// Construct KBS resource URL: <kbs_url>/kbs/v0/resource/<resource_path>
	kbsResourceURL := fmt.Sprintf("%s/kbs/v0/resource/%s",
		strings.TrimRight(kbsURL, "/"), strings.TrimLeft(resourcePath, "/"))

	as.logger.Info("fetching key from KBS", "url", kbsResourceURL)

	// Use a simple HTTP GET to KBS for now.
	// In a full CoCo deployment, this would go through the Attestation Agent
	// which performs attestation before KBS releases the key.
	// For non-OCI resources, the AA/KBS handshake may need to be handled
	// differently than via ocicrypt.
	resp, err := kbsHTTPGet(ctx, kbsResourceURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch key from KBS at %s: %w", kbsResourceURL, err)
	}

	return resp, nil
}
// kbsHTTPGet performs an HTTP GET to the KBS endpoint and returns the response
// body.
//
// The request honors ctx and is additionally bounded by a fixed overall
// timeout, so a stalled KBS cannot block the caller indefinitely even when ctx
// carries no deadline. Any status other than 200 OK is reported as an error.
func kbsHTTPGet(ctx context.Context, endpoint string) ([]byte, error) {
	// Upper bound for the whole exchange, including reading the body.
	const kbsRequestTimeout = 30 * time.Second

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	client := &http.Client{Timeout: kbsRequestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("KBS returned status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	return body, nil
}
// downloadAndDecryptOCIImage pulls an OCI image through the OCI client, extracts
// the requested resource from it and returns its bytes.
//
// For an "algorithm" resource with a declared algorithm, a Docker-type
// algorithm is converted to a Docker archive tarball; any other type is
// extracted with oci.ExtractAlgorithm, which may also locate a requirements
// file. Everything else is treated as a dataset and extracted with
// oci.ExtractDataset. The returned Data is produced by internal.Digest over the
// single extracted file, or over the extraction directory when several files
// were extracted. SourceDir is always the extraction directory.
//
// It returns an error if the OCI client is unavailable or if pulling,
// converting, extracting (including an extraction that yields no files) or
// digesting fails. A requirements file that cannot be read is only logged.
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) {
	as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s kbs_url=%s)",
		source.URL, source.Encrypted, source.KBSResourcePath, kbsURL))

	if as.ociClient == nil {
		return nil, fmt.Errorf("OCI client not initialized")
	}

	uri := source.URL
	// If the URI is just an image name without a transport scheme, default to docker://
	if !strings.Contains(uri, "://") && !strings.HasPrefix(uri, "oci:") && !strings.HasPrefix(uri, "docker-archive:") && !strings.HasPrefix(uri, "dir:") {
		uri = "docker://" + uri
	}

	// Create OCI resource source.
	ociSource := oci.ResourceSource{
		Type:            oci.ResourceTypeOCIImage,
		URI:             uri,
		Encrypted:       source.Encrypted,
		KBSResourcePath: source.KBSResourcePath,
		KBSURL:          kbsURL,
	}

	// Pull and decrypt image.
	// CoCo Keyprovider will automatically handle decryption via ocicrypt.
	// Sanitize directory name to avoid Skopeo interpreting ':' as tag separator.
	sanitizedName := strings.ReplaceAll(filepath.Base(source.URL), ":", "_")
	destDir := filepath.Join(os.TempDir(), "cocos-oci", "images", sanitizedName)
	if err := as.ociClient.PullAndDecrypt(ctx, ociSource, destDir); err != nil {
		return nil, fmt.Errorf("failed to pull and decrypt OCI image: %w", err)
	}

	as.logger.Info("OCI image downloaded and decrypted", "dest", destDir)

	// Extract the resource from the OCI layers.
	extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName)

	var (
		algorithmPath    string
		requirementsPath string
		files            []string
		err              error
	)

	if resourceType == "algorithm" && as.computation.Algorithm != nil {
		if as.computation.Algorithm.AlgoType == string(algorithm.AlgoTypeDocker) {
			// For Docker algorithms, convert OCI image to Docker archive tarball.
			algorithmPath = filepath.Join(extractDir, "image.tar")
			if err := os.MkdirAll(extractDir, 0o755); err != nil {
				return nil, fmt.Errorf("failed to create extract directory: %w", err)
			}
			if err := as.ociClient.ToDockerArchive(ctx, destDir, algorithmPath); err != nil {
				return nil, fmt.Errorf("failed to convert OCI image to Docker archive: %w", err)
			}
			as.logger.Info("OCI image converted to Docker archive", "path", algorithmPath)
			files = []string{algorithmPath}
		} else {
			algorithmPath, requirementsPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir, as.computation.Algorithm.AlgoType)
			if err != nil {
				return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
			}
			as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
			files = []string{algorithmPath}
		}
	} else {
		// Assume dataset.
		files, err = oci.ExtractDataset(destDir, extractDir)
		if err != nil {
			return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err)
		}
		// Checked separately from err so a nil error is never wrapped with %w.
		if len(files) == 0 {
			return nil, fmt.Errorf("failed to extract dataset from OCI image: no files found in %s", destDir)
		}
		// Set algorithmPath to the first file for the requirements fallback below.
		algorithmPath = files[0]
		as.logger.Info("dataset extracted from OCI image", "num_files", len(files))
	}

	// Determine which path to hash based on extraction results.
	// For algorithms, we always hash the specific algorithm file found.
	// For datasets, if there's only one file, hash it directly.
	// If multiple files, hash the directory (which zips it).
	hashPath := extractDir
	if len(files) == 1 {
		hashPath = files[0]
	}

	// Calculate digest (matches internal.Checksum logic).
	resourceData, _, err := internal.Digest(hashPath)
	if err != nil {
		return nil, fmt.Errorf("failed to calculate resource digest: %w", err)
	}

	// Read requirements file if found (only for algorithms).
	var reqData []byte
	if resourceType == "algorithm" {
		if requirementsPath != "" {
			reqData, err = os.ReadFile(requirementsPath)
			if err != nil {
				as.logger.Warn("failed to read requirements file", "path", requirementsPath, "error", err)
			} else {
				as.logger.Info("requirements.txt loaded", "size", len(reqData))
			}
		} else {
			// Fallback: check if requirements.txt exists in the same directory as the algorithm.
			reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
			if data, err := os.ReadFile(reqPath); err == nil {
				reqData = data
				as.logger.Info("found requirements.txt via fallback", "size", len(data))
			}
		}
	}

	as.logger.Info("resource loaded from OCI", "type", resourceType, "size", len(resourceData), "hash_path", hashPath)

	return &DecryptedResource{
		Data:         resourceData,
		Requirements: reqData,
		SourceDir:    extractDir,
	}, nil
}
// Algo stores the algorithm for the current computation.
//
// It is accepted only in the ReceivingAlgorithm state (ErrStateNotReady
// otherwise) and only once (ErrAllManifestItemsReceived afterwards). The
// manifest must declare an algorithm (ErrUndeclaredAlgorithm). When the
// manifest gives a remote source with KBS enabled the algorithm is downloaded;
// otherwise the uploaded bytes are used. The bytes must match the manifest's
// SHA3-256 hash (ErrHashMismatch). The algorithm is then written to "algo" in
// the working directory with owner-only permissions, the datasets directory is
// ensured, and AlgorithmReceived is sent.
//
// Service state is updated only after every fallible step has succeeded, so a
// failed call can be retried. Requirements supplied with the request take
// precedence; requirements found in a downloaded image are used when the
// request carries none.
func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
	if as.sm.GetState() != ReceivingAlgorithm {
		return ErrStateNotReady
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	if as.algoReceived {
		return ErrAllManifestItemsReceived
	}

	if as.computation.Algorithm == nil {
		return ErrUndeclaredAlgorithm
	}

	kbsEnabled := as.computation.Algorithm.KBS != nil && as.computation.Algorithm.KBS.Enabled
	kbsURL := ""
	if as.computation.Algorithm.KBS != nil {
		kbsURL = as.computation.Algorithm.KBS.URL
	}

	var (
		algoData     []byte
		requirements []byte
	)
	requirements = algo.Requirements

	if as.computation.Algorithm.Source != nil && kbsEnabled {
		as.logger.Info("downloading algorithm from remote source", "kbs_url", kbsURL)
		res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, kbsURL, "algorithm")
		if err != nil {
			return fmt.Errorf("failed to download and decrypt algorithm: %w", err)
		}
		algoData = res.Data
		// Keep requirements discovered in the download unless the request
		// explicitly provided its own.
		if len(requirements) == 0 {
			requirements = res.Requirements
		}
	} else {
		// Use directly uploaded algorithm.
		algoData = algo.Algorithm
	}

	hash := sha3.Sum256(algoData)
	if hash != as.computation.Algorithm.Hash {
		return ErrHashMismatch
	}

	currentDir, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error getting current directory: %w", err)
	}

	f, err := os.Create(filepath.Join(currentDir, "algo"))
	if err != nil {
		return fmt.Errorf("error creating algorithm file: %w", err)
	}

	if _, err := f.Write(algoData); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = f.Close()
		return fmt.Errorf("error writing algorithm to file: %w", err)
	}

	if err := os.Chmod(f.Name(), algoFilePermission); err != nil {
		// Best-effort close; the chmod error is the one worth reporting.
		_ = f.Close()
		return fmt.Errorf("error changing file permissions: %w", err)
	}

	if err := f.Close(); err != nil {
		return fmt.Errorf("error closing file: %w", err)
	}

	if err := ensureDir(algorithm.DatasetsDir, 0o755); err != nil {
		return fmt.Errorf("error creating datasets directory: %w", err)
	}

	algoType := algorithm.AlgorithmTypeFromContext(ctx)
	if algoType == "" {
		algoType = string(algorithm.AlgoTypeBin)
	}

	as.algoType = algoType
	as.algoArgs = algorithm.AlgorithmArgsFromContext(ctx)
	as.algoRequirements = requirements
	as.algoReceived = true

	as.sm.SendEvent(AlgorithmReceived)

	return nil
}
// Data accepts one dataset declared in the computation manifest and stores it
// under algorithm.DatasetsDir.
//
// The call is only honoured while the state machine is in ReceivingData and
// the manifest still lists outstanding datasets. The first outstanding
// manifest entry that declares a remote source with KBS enabled is downloaded
// and decrypted; when there is none, the bytes and filename uploaded in
// dataset are used. The SHA3-256 digest of the bytes is then matched against
// the outstanding manifest entries.
//
// The payload is either unzipped into the datasets directory (when the
// context requests decompression) or written to a file named after the
// dataset. The manifest entry is removed only after the payload has been
// persisted, so a failed call can be retried. Once no entries remain,
// DataReceived is sent to the state machine.
//
// It returns ErrStateNotReady, ErrAllManifestItemsReceived,
// ErrFileNameMismatch or ErrUndeclaredDataset when the respective check
// fails, or a wrapped error when the download or a file-system step fails.
func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
	if as.sm.GetState() != ReceivingData {
		return ErrStateNotReady
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	if len(as.computation.Datasets) == 0 {
		return ErrAllManifestItemsReceived
	}

	// Start from the uploaded payload and override it when a manifest entry
	// asks for a remote download.
	datasetData := dataset.Dataset
	datasetFilename := dataset.Filename
	for _, d := range as.computation.Datasets {
		kbsEnabled := d.KBS != nil && d.KBS.Enabled
		if d.Source == nil || !kbsEnabled {
			continue
		}
		kbsURL := d.KBS.URL
		as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL)
		downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset")
		if err != nil {
			return fmt.Errorf("failed to download and decrypt dataset: %w", err)
		}
		datasetData = downloadedData.Data
		datasetFilename = d.Filename
		break
	}

	hash := sha3.Sum256(datasetData)

	matchedIndex := -1
	for i, d := range as.computation.Datasets {
		if hash != d.Hash {
			continue
		}
		if d.Filename != "" && d.Filename != datasetFilename {
			return ErrFileNameMismatch
		}
		matchedIndex = i
		break
	}
	if matchedIndex == -1 {
		return ErrUndeclaredDataset
	}

	if DecompressFromContext(ctx) {
		if err := internal.UnzipFromMemory(datasetData, algorithm.DatasetsDir); err != nil {
			return fmt.Errorf("error decompressing dataset: %w", err)
		}
	} else {
		// The filename may come straight from the client (the manifest entry
		// can leave it empty), so refuse anything that would escape the
		// datasets directory.
		if !filepath.IsLocal(datasetFilename) {
			return fmt.Errorf("invalid dataset filename %q", datasetFilename)
		}
		f, err := os.Create(filepath.Join(algorithm.DatasetsDir, datasetFilename))
		if err != nil {
			return fmt.Errorf("error creating dataset file: %w", err)
		}
		if _, err := f.Write(datasetData); err != nil {
			// Best-effort close: the write error is the one worth reporting.
			_ = f.Close()
			return fmt.Errorf("error writing dataset to file: %w", err)
		}
		if err := f.Close(); err != nil {
			return fmt.Errorf("error closing file: %w", err)
		}
	}

	// Only drop the manifest entry once the payload is safely on disk.
	as.computation.Datasets = slices.Delete(as.computation.Datasets, matchedIndex, matchedIndex+1)

	if len(as.computation.Datasets) == 0 {
		as.sm.SendEvent(DataReceived)
	}

	return nil
}
// Result hands the stored computation result, together with any run error,
// to a result consumer.
//
// The call is rejected with ErrResultsNotReady unless the state machine is in
// ConsumingResults, Complete or Failed. The consumer index is taken from ctx;
// a missing index, or one outside the manifest's ResultConsumers list, yields
// ErrUndeclaredConsumer. The first successful call made while in
// ConsumingResults marks the results as consumed and sends ResultsConsumed to
// the state machine as the method returns.
func (as *agentService) Result(ctx context.Context) ([]byte, error) {
	state := as.sm.GetState()
	ready := state == ConsumingResults || state == Complete || state == Failed
	if !ready {
		return []byte{}, ErrResultsNotReady
	}

	consumer, found := IndexFromContext(ctx)
	if !found {
		return []byte{}, ErrUndeclaredConsumer
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	declared := consumer >= 0 && consumer < len(as.computation.ResultConsumers)
	if !declared {
		return []byte{}, ErrUndeclaredConsumer
	}

	firstConsumption := state == ConsumingResults && !as.resultsConsumed
	if firstConsumption {
		as.resultsConsumed = true
		// Deferred so the event fires after the return values are captured.
		defer as.sm.SendEvent(ResultsConsumed)
	}

	return as.result, as.runError
}
// Attestation requests an attestation from the attestation client for the
// given report data, nonce and platform type and returns it unmodified.
//
// When the client fails, an empty slice is returned and the client's error is
// wrapped with ErrAttestationFailed.
func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
	quote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType)
	if err == nil {
		return quote, nil
	}

	return []byte{}, errors.Wrap(ErrAttestationFailed, err)
}
// AzureAttestationToken requests an Azure token from the attestation client
// for the given nonce and returns it unmodified.
//
// When the client fails, an empty slice is returned together with the
// client's error, which is passed through unwrapped.
func (as *agentService) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) {
	azureToken, err := as.attestationClient.GetAzureToken(ctx, nonce)
	if err == nil {
		return azureToken, nil
	}

	return []byte{}, err
}
// runComputation executes the received algorithm through the runner client
// and stores the zipped contents of algorithm.ResultsDir as the result.
//
// Progress is reported with Starting, InProgress and then Completed or Failed
// events. Any failure is recorded in as.runError. When the function returns,
// the results directory, the datasets directory and the algorithm file are
// removed, and then RunFailed or RunComplete is sent to the state machine
// depending on whether a run error was recorded.
func (as *agentService) runComputation(state statemachine.State) {
	as.publishEvent(Starting.String())(state)
	as.logger.Debug("computation run started")

	// Registered first so it runs last: the state transition happens only
	// after cleanup has finished.
	defer func() {
		as.mu.Lock()
		defer as.mu.Unlock()
		if as.runError != nil {
			as.sm.SendEvent(RunFailed)
		} else {
			as.sm.SendEvent(RunComplete)
		}
	}()

	// fail records runErr under the mutex, logs logMsg and publishes the
	// Failed event. The caller is expected to return right after.
	fail := func(runErr error, logMsg string) {
		as.mu.Lock()
		as.runError = runErr
		as.mu.Unlock()
		as.logger.Warn(logMsg)
		as.publishEvent(Failed.String())(state)
	}

	currentDir, err := os.Getwd()
	if err != nil {
		// Not fatal: with an empty currentDir the path below is the relative
		// "algo", and a genuinely unusable directory surfaces as a read
		// failure further down.
		as.logger.Warn(fmt.Sprintf("error getting current directory: %s", err.Error()))
	}
	algoFile := filepath.Join(currentDir, "algo")

	defer func() {
		if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
		}
		if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
		}
		if err := os.Remove(algoFile); err != nil && !os.IsNotExist(err) {
			as.logger.Warn(fmt.Sprintf("error removing algorithm file: %s", err.Error()))
		}
	}()

	if err := ensureDir(algorithm.ResultsDir, 0o755); err != nil {
		runErr := fmt.Errorf("error creating results directory: %w", err)
		fail(runErr, runErr.Error())
		return
	}

	algoBytes, err := os.ReadFile(algoFile)
	if err != nil {
		runErr := fmt.Errorf("failed to read algo file: %w", err)
		fail(runErr, runErr.Error())
		return
	}

	as.publishEvent(InProgress.String())(state)

	// Datasets are not part of the request; the runner picks them up from the
	// shared file system.
	resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{
		ComputationId: as.computation.ID,
		AlgoType:      as.algoType,
		Algorithm:     algoBytes,
		Requirements:  as.algoRequirements,
		Args:          as.algoArgs,
	})
	if err != nil {
		fail(err, fmt.Sprintf("failed to run computation: %s", err.Error()))
		return
	}
	if resp.Error != "" {
		fail(errors.New(resp.Error), fmt.Sprintf("failed to run computation: %s", resp.Error))
		return
	}

	results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
	if err != nil {
		fail(err, fmt.Sprintf("failed to zip results: %s", err.Error()))
		return
	}

	as.publishEvent(Completed.String())(state)

	// Result reads as.result while holding the mutex, so write it under the
	// same lock.
	as.mu.Lock()
	as.result = results
	as.mu.Unlock()
}
// publishEvent builds a state-machine action that reports status for the
// state it is invoked with.
//
// The returned action sends the computation ID, the state's string form, the
// given status and an empty JSON payload to the event service. The
// computation ID is read when the action runs, not when it is built.
func (as *agentService) publishEvent(status string) statemachine.Action {
	notify := func(current statemachine.State) {
		details := json.RawMessage{}
		as.eventSvc.SendEvent(as.computation.ID, current.String(), status, details)
	}

	return notify
}
// IMAMeasurements returns the contents of the Linux IMA measurements file at
// ImaMeasurementsFilePath together with the SHA-1 value of the TPM PCR
// selected by ImaPcrIndex.
//
// Both results are nil when either read fails; the underlying error is
// wrapped so callers can inspect it with errors.Is / errors.As.
func (as *agentService) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) {
	data, err := os.ReadFile(ImaMeasurementsFilePath)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading Linux IMA measurements file: %w", err)
	}

	pcr10, err := vtpm.GetPCRSHA1Value(ImaPcrIndex)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading TPM PCR #10: %w", err)
	}

	return data, pcr10, nil
}