mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
da31d76c94
* feat(kbs): implement KBS client for attestation and resource retrieval - Added KBS client implementation in pkg/kbs/client.go with methods for attestation and resource retrieval. - Introduced necessary data structures for requests and responses. - Implemented error handling for various scenarios. test(kbs): add unit tests for KBS client - Created comprehensive tests for the KBS client in pkg/kbs/client_test.go. - Included tests for attestation success and failure cases, as well as resource retrieval. feat(registry): introduce HTTP and S3 registry implementations - Added HTTPRegistry for downloading resources over HTTP/HTTPS with retry logic in pkg/registry/http.go. - Implemented S3Registry for downloading resources from AWS S3 and S3-compatible services in pkg/registry/s3.go. - Included error handling and configuration options for both registries. chore(registry): define registry interface and configuration - Created registry interface and configuration struct in pkg/registry/registry.go. - Added default configuration settings for registry clients. docs(cvms): update README for CVMS server configuration and usage - Enhanced documentation for CVMS server with detailed command-line flags and usage examples. - Clarified direct upload and remote resource modes, including KBS integration. fix(cvms): integrate KBS for remote resource handling in main.go - Updated main.go to support remote datasets and algorithms using KBS. - Added validation for command-line flags to ensure proper configuration. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Move ifeq conditional outside define block in attestation-service.mk Make conditionals cannot be evaluated inside define...endef blocks when used as recipe bodies. Restructured to define the ATTESTATION_SERVICE_INSTALL_INIT_SYSTEMD block conditionally based on BR2_PACKAGE_CC_ATTESTATION_AGENT configuration. * feat: Implement remote resource downloading for algorithms and datasets using AWS S3/MinIO credentials. 
Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add comprehensive documentation and agent support for testing remote resource download with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Improve agent logging for remote resource configuration and KBS status, and add a testing guide for remote resource downloads with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add a comprehensive guide for testing remote resource download with KBS attestation and update multiple package versions to a specific commit. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add failure transitions for resource reception states and a comprehensive guide for testing remote resource downloads with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Implement remote resource download with KBS attestation in the agent and add a comprehensive testing guide. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: Add comprehensive guide for testing remote resource download with KBS attestation and include a debug log in the attestation client. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Delegate KBS attestation and token retrieval to a new attestation-agent service and document remote resource testing. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * client fixes Signed-off-by: Sammy Oina <sammyoina@gmail.com> * raw evidence Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Build all Go files in cmd directories, not just main.go This fixes the issue where fetch_raw_evidence.go wasn't being included in the attestation-service build. * fix: Wrap binary evidence in JSON for KBS compatibility Fixes 'invalid character' error by wrapping raw binary evidence in a JSON structure with base64 encoding, as expected by KBS. * chore: Update buildroot packages to c28cefae Includes fixes for: 1. attestation-service build (including fetch_raw_evidence.go) 2. 
Agent KBS evidence format (wrapping binary in JSON) * fix: Implement KBS RCAR handshake with cookies Fixes 'cookie not found' error (401) from KBS by: 1. Adding CookieJar support to KBS client 2. Implementing GetChallenge() to perform /auth handshake and capture session cookie 3. Updating Agent to get challenge, decode nonce, and use it for evidence generation 4. Regenerating mocks * chore: Update buildroot packages to f6981ac5 Includes KBS RCAR handshake fix (cookie support + GetChallenge loop) * fix: Update KBS client JSON tags to kebab-case Fixes deserialization error (401) from KBS by: 1. Using kebab-case (e.g. extra-params) for JSON tags as per protocol. 2. Initializing ExtraParams as empty object {} instead of null/omitted. * fix: Wrap attestation evidence in primary_evidence format Updates Agent to construct 'tee-evidence' payload with: - primary_evidence: containing the actual quote/data - additional_evidence: empty JSON object This matches the Confidential Containers KBS Attestation Protocol requirements. * fix: Update KBS protocol version to 0.4.0 KBS rejected 0.1.0 with a version mismatch error. Bumping to 0.4.0 to match server expectation. * fix: Generate ephemeral key for KBS RuntimeData Updates RuntimeData to include a valid ephemeral EC P-256 public key in JWK format, as required by the KBS RCAR protocol. Also fixes the KBS client struct to support TEEPubKey as an object. * fix: Update sample attestation quote to valid JSON The default attestation.bin was binary, but the KBS Sample Verifier expects a valid JSON quote containing 'svn' and 'report_data'. Updated the embedded bin file to contain this JSON structure. * fix: Generate dynamic JSON quote for Sample TEE in FetchRawEvidence The KBS Sample Verifier expects a JSON object with 'svn' and 'report_data'. Previously, we were returning raw binary data (reportData+nonce). 
This commit updates FetchRawEvidence to return a marshaled JSON structure with: - svn: "1" - report_data: base64(req.ReportData) * refactor: Delegate Sample Attestation to Provider Refactored sample attestation logic: - Moved JSON Quote generation into EmptyProvider (standalone mode). - Updated FetchRawEvidence to call provider.TeeAttestation instead of manual generation. This enables using the real CC Attestation Agent for UNSPECIFIED platform if configured. * feat: Add comprehensive debug logging and enforce CC AA usage Changes: - Updated EmptyProvider to return error instead of generating mock data This forces proper use of CC Attestation Agent's sample attester - Added detailed logging to attestation-service FetchRawEvidence: * Hex dump of evidence (first 200 bytes) * String preview of evidence * Total evidence length - Added detailed logging to agent service: * Raw evidence hex and string previews * KBS evidence JSON preview (first 500 bytes) * Evidence lengths at each transformation step This logging will help diagnose why KBS Sample Verifier is rejecting evidence. 
* fix: Enable CC AA by default and add attestation-service log forwarding Changes: - Set USE_CC_ATTESTATION_AGENT=true by default in systemd service - Added StandardOutput/StandardError to forward logs to /var/log/cocos/ - Updated HAL makefile to handle new default value - This ensures attestation-service uses CC AA's sample attester - Logs will now be visible in CVMS output for debugging * feat: Add gRPC log forwarding to attestation-service Implemented the same log forwarding mechanism used by the agent: - Added ProtoHandler to write logs to both stdout and logQueue - Connected to log client (/run/cocos/log.sock) for gRPC forwarding - Added goroutine to forward logs to CVMS via log client - Logs will now appear in CVMS output during computation runs This enables visibility into attestation-service debug output including: - CC AA connection status - Evidence generation details (hex dumps, string previews) - Any errors from providers * fix: Parse sample evidence JSON instead of base64-encoding it The attestation-service returns sample evidence as JSON: {"svn":"1","report_data":"base64..."} The agent was incorrectly base64-encoding this JSON string again. KBS Sample Verifier expects the parsed JSON object directly. Fixed by: - Parsing the JSON evidence from attestation-service - Passing the parsed object directly in primary_evidence.evidence - This matches what KBS Sample Verifier expects * debug: Increase KBS evidence logging preview to 1000 bytes Show the complete JSON structure being sent to KBS to debug the attestation failure. * debug: Add comprehensive CC AA configuration logging Added debug logs to show: - Whether CC AA is enabled in config - CC AA address being used - Connection success/failure - Which provider is ultimately selected - Warning when falling back to EmptyProvider This will help diagnose why EmptyProvider is being used instead of CC Attestation Agent. 
* debug: Add startup logging for log client connection Added log message to show if log client connection succeeds at attestation-service startup. This will help diagnose why logs aren't appearing in CVMS output. * feat: Add retry logic with exponential backoff to log client Added simple retry mechanism to handle concurrent log requests: - 3 retry attempts with exponential backoff (10ms, 20ms, 40ms) - Applies to both SendLog and SendEvent methods - Centralized in log client so all services benefit - Should eliminate 'failed to send log' errors from concurrent requests This fixes the issue where attestation-service logs weren't appearing in CVMS output due to dropped messages. * fix: Flatten sample evidence fields in primary_evidence for KBS KBS Sample Verifier expects svn and report_data at the top level of primary_evidence, not nested under an 'evidence' key. Changed structure from: {"primary_evidence": {"tee": "sample", "evidence": {"svn": "1", ...}}} To: {"primary_evidence": {"tee": "sample", "svn": "1", "report_data": "...", ...}} This matches what KBS expects when deserializing the Quote structure. * fix: Use sample quote directly as primary_evidence per KBS protocol According to KBS attestation protocol spec, for sample TEE type, primary_evidence should be the sample quote JSON directly: {"svn": "1", "report_data": "..."} Removed extra 'tee' and 'platform' fields that were causing KBS to fail deserializing the Quote structure. The 'tee' field is already sent in the Request payload during RCAR handshake. Refs: - https://github.com/confidential-containers/trustee/blob/main/kbs/docs/kbs_attestation_protocol.md - https://github.com/confidential-containers/guest-components/blob/main/attestation-agent/attester/src/sample/mod.rs * fix: Make CC AA required for sample attestation when configured When USE_CC_ATTESTATION_AGENT=true, attestation-service now requires AA to be available for NoCC/sample platform. 
This ensures sample evidence always comes from AA with the correct KBS format. Changes: - Error out if AA connection fails for NoCC platform when AA is configured - Only use EmptyProvider if AA is explicitly NOT configured - Prevents incorrect sample evidence format from EmptyProvider This ensures attestation-service delegates to AA for sample evidence generation instead of creating it itself. * fix: Implement proper RCAR protocol with tee-pubkey and runtime-data hash Fixed KBS attestation error 'REPORT_DATA is different from that in Sample Quote' Changes: 1. Generate ephemeral EC key pair BEFORE getting evidence from AA 2. Create runtime-data with nonce + tee-pubkey (JWK format) 3. Hash runtime-data (SHA-256) and use as report_data for AA 4. This binds the tee-pubkey to the TEE evidence per RCAR protocol The report_data in the evidence now matches what KBS expects: hash(runtime-data) instead of computation ID. This completes the full RCAR protocol implementation: - Request → Challenge → Attestation (with bound tee-pubkey) → Response * fix(agent): use simple nonce for Sample attestation report_data For Sample/NoCC attestation, use the raw nonce bytes directly as report_data instead of hashing runtime-data. This avoids JSON serialization mismatches with the KBS Sample verifier. Real TEEs (TDX/SNP) still use runtime-data hash binding to cryptographically bind the ephemeral tee-pubkey to the evidence. * fix(agent): use RFC 8785 canonical JSON for runtime-data hashing The KBS Sample attestation verifier (and likely others) expects the report_data to be the SHA-256 hash of the *canonical* JSON serialization (RFC 8785) of the runtime-data. Standard Go JSON marshaling does not guarantee key ordering, leading to hash mismatches. This change uses github.com/gowebpki/jcs to canonicalize the runtime-data before hashing, ensuring compatibility with the KBS RCAR implementation. Also reverted the temporary 'simple nonce' workaround. 
* feat(hal): add CoCo Keyprovider and Skopeo packages - Add coco-keyprovider buildroot package with systemd service - Add skopeo buildroot package for OCI image handling - Add ocicrypt_keyprovider.conf for encrypted image decryption - Update Config.in to include new packages This enables standard CoCo ecosystem integration for encrypted OCI images instead of custom S3/HTTP registry clients. * feat(oci): add OCI image handling package with Skopeo integration - Add pkg/oci/types.go with ResourceSource and ImageManifest types - Add pkg/oci/skopeo.go with Skopeo wrapper for pull/decrypt - Add pkg/oci/extract.go for extracting algorithms and datasets from layers This package provides OCI image handling using Skopeo and CoCo Keyprovider for encrypted image decryption, replacing custom S3/HTTP registry clients. * chore: regenerate protobuf files for updated cvms.proto * refactor(agent): replace S3/HTTP/KBS with OCI package - Remove pkg/kbs and pkg/registry imports - Add pkg/oci import for OCI image handling - Replace downloadAndDecryptResource with OCI-based implementation - Use Skopeo + CoCo Keyprovider for automatic decryption - Reduce code from ~240 lines to ~70 lines This eliminates custom KBS RCAR handshake, S3/HTTP registry clients, and manual decryption logic. CoCo Keyprovider handles all decryption automatically via ocicrypt protocol. * chore: remove obsolete pkg/kbs and pkg/registry packages - Delete pkg/kbs/ (custom KBS client, ~300 lines) - Delete pkg/registry/ (S3/HTTP registry clients, ~400 lines) - Remove unused imports from agent/service.go - Run go mod tidy to clean up dependencies These packages have been replaced by pkg/oci with Skopeo and CoCo Keyprovider for standard CoCo ecosystem integration. 
* fix(agent): update ResourceSource struct to include type and encryption fields Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix(hal): update CoCo Keyprovider to v0.16.0 and fix build path - Update version from v0.11.0 to v0.16.0 (matches attestation agent) - Fix install path: target is at repo root, not in coco_keyprovider subdir - This fixes the build error where coco_keyprovider binary wasn't found The cargo workspace in guest-components builds to a shared target/ directory at the repository root, not within each crate's subdirectory. * feat: Update remote resources testing guide to use kbs-client and coco-keyprovider for key management and encryption, enable insecure TLS for Skopeo, and enhance CVMS with Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update component versions, revise image encryption documentation, and sanitize OCI image paths for Skopeo compatibility. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add `decompress` option to Dataset and `algo_type`/`algo_args` to Algorithm protobuf messages, updating client, test, and build configurations. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update multiple package versions and enhance OCI image extraction error reporting for missing algorithm files. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Bump package versions, improve OCI image extraction debugging by returning seen files, and remove unused dataset type parsing from test code. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Migrate OCI extraction to use structured logging with `slog` and `context`, and update package versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Bump multiple component versions, add encrypted status for computation inputs and algorithms, and refine OCI layer extraction warnings. 
Signed-off-by: Sammy Oina <sammyoina@gmail.com> * logging Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add `Encrypted` field to algorithm and dataset resource sources and update all component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: update component versions, integrate coco-keyprovider service, and configure ocicrypt key provider. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add support for KBS parameters and dataset/algorithm hash calculations in CVMS Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: update resource download and extraction logic to support requirements.txt and improve hash verification Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Update dependencies, improve code style, and add GetRawEvidence to attestation client mocks. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor code structure for improved readability and maintainability Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: update golangci configuration to include errcheck for build path and remove unnecessary exclusions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: streamline kernel command line handling in QEMU args construction Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add attestation binary and update checksum tests and policy structure Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for attestation agent, attestation, log, crypto, OCI, and Skopeo clients - Implement tests for the attestation agent client including Unix socket and TCP address handling, token retrieval, and error scenarios. - Enhance attestation client tests to cover fetching raw evidence for various platforms (SNP, TDX, VTPM, SNPvTPM) and validate error handling. - Introduce log client tests to verify retry behavior for sending logs and events. - Create comprehensive tests for crypto package focusing on AES-GCM decryption, encrypted resource parsing, and key unwrapping. 
- Add tests for OCI package to validate algorithm and dataset extraction, including JSON serialization of OCILayout. - Implement Skopeo client tests to ensure proper functionality for image pulling, inspecting, and resource source handling. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: handle JSON marshal errors in test cases for decrypt and extract functions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add comprehensive tests for algorithm and dataset extraction with various scenarios Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: replace hardcoded Python script content with constant variable Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: remove redundant mock expectation for SendAgentConfig in TestCreateVMWithAaKbsParams Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add tests for event sending failure, dataset extraction with path traversal, and Skopeo client behavior Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add tests for download and decryption of resources with various URL formats Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Introduce OCIClient interface for agent service to improve testability of OCI image operations and enhance related tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Change `get_uint64_from_tcb` to accept `TcbVersion` by value and use `u64::from` for type conversions. --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
939 lines
29 KiB
Go
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package agent

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/absmach/supermq/pkg/errors"
	"github.com/ultravioletrs/cocos/agent/algorithm"
	"github.com/ultravioletrs/cocos/agent/events"
	runnerpb "github.com/ultravioletrs/cocos/agent/runner"
	"github.com/ultravioletrs/cocos/agent/statemachine"
	"github.com/ultravioletrs/cocos/internal"
	"github.com/ultravioletrs/cocos/pkg/attestation"
	"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
	attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
	runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
	"github.com/ultravioletrs/cocos/pkg/oci"
	"golang.org/x/crypto/sha3"
)

var _ Service = (*agentService)(nil)

//go:generate stringer -type=AgentState
type AgentState int

const (
	Idle AgentState = iota
	ReceivingManifest
	ReceivingAlgorithm
	ReceivingData
	Running
	ConsumingResults
	Complete
	Failed
)

//go:generate stringer -type=AgentEvent
type AgentEvent int

const (
	Start AgentEvent = iota
	ManifestReceived
	AlgorithmReceived
	DataReceived
	RunComplete
	ResultsConsumed
	RunFailed
)

//go:generate stringer -type=Status
type Status uint8

const (
	IdleState Status = iota
	InProgress
	Ready
	Completed
	Terminated
	Warning
	Starting
)

const (
	algoFilePermission = 0o700
)

var (
	ImaMeasurementsFilePath = "/sys/kernel/security/integrity/ima/ascii_runtime_measurements"
	ImaPcrIndex             = 10
)

var (
	// ErrMalformedEntity indicates malformed entity specification (e.g.
	// invalid username or password).
	ErrMalformedEntity = errors.New("malformed entity specification")
	// ErrUnauthorizedAccess indicates missing or invalid credentials provided
	// when accessing a protected resource.
	ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
	// ErrUndeclaredDataset indicates the dataset was not declared in the computation manifest.
	ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")
	// ErrAllManifestItemsReceived indicates no new computation manifest items are expected.
	ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")
	// ErrUndeclaredConsumer indicates the consumer requesting results is not declared in the computation manifest.
	ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")
	// ErrResultsNotReady indicates the computation results are not ready.
	ErrResultsNotReady = errors.New("computation results are not yet ready")
	// ErrStateNotReady indicates the agent received a request in the wrong state.
	ErrStateNotReady = errors.New("agent not expecting this operation in the current state")
	// ErrHashMismatch indicates the provided algorithm/dataset does not match the hash in the manifest.
	ErrHashMismatch = errors.New("malformed data, hash does not match manifest")
	// ErrFileNameMismatch indicates the provided dataset filename does not match the filename in the manifest.
	ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")
	// ErrAllResultsConsumed indicates all results have been consumed.
	ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")
	// ErrAttestationFailed indicates that attestation failed.
	ErrAttestationFailed = errors.New("failed to get raw quote")
	// ErrAttestationVTpmFailed indicates that vTPM attestation failed.
	ErrAttestationVTpmFailed = errors.New("failed to get vTPM quote")
	// ErrFetchAzureToken indicates that fetching the Azure token failed.
	ErrFetchAzureToken = errors.New("failed to get azure token")
	// ErrAttestationType indicates that the requested attestation type does not exist or is not supported.
	ErrAttestationType = errors.New("attestation type does not exist or is not supported")
)

// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
	InitComputation(ctx context.Context, cmp Computation) error
	StopComputation(ctx context.Context) error
	Algo(ctx context.Context, algorithm Algorithm) error
	Data(ctx context.Context, dataset Dataset) error
	Result(ctx context.Context) ([]byte, error)
	Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
	IMAMeasurements(ctx context.Context) ([]byte, []byte, error)
	AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error)
	State() string
}

type OCIClient interface {
	PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
}

type agentService struct {
	mu                sync.Mutex
	computation       Computation // Holds the current computation request details.
	runnerClient      runner_client.Client
	algoType          string
	algoArgs          []string
	algoRequirements  []byte
	algoReceived      bool
	result            []byte                    // Stores the result of the computation.
	sm                statemachine.StateMachine // Manages the state transitions of the agent service.
	runError          error                     // Stores any error encountered during the computation run.
	eventSvc          events.Service            // Service for publishing events related to computation.
	attestationClient attestation_client.Client // Client for attestation service.
	logger            *slog.Logger              // Logger for the agent service.
	resultsConsumed   bool                      // Indicates if the results have been consumed.
	cancel            context.CancelFunc        // Cancels the computation context.
	vmpl              int                       // VMPL at which the Agent is running.
	ociClient         OCIClient
}

// New instantiates the agent service implementation.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmpl int) Service {
	sm := statemachine.NewStateMachine(Idle)
	ctx, cancel := context.WithCancel(ctx)
	svc := &agentService{
		sm:                sm,
		eventSvc:          eventSvc,
		attestationClient: attestationClient,
		runnerClient:      runnerClient,
		logger:            logger,
		cancel:            cancel,
		vmpl:              vmpl,
	}

	workDir := filepath.Join(os.TempDir(), "cocos-oci")
	skopeoClient, err := oci.NewSkopeoClient(workDir)
	if err != nil {
		logger.Warn("failed to create Skopeo client", "error", err)
	}
	svc.ociClient = skopeoClient

	transitions := []statemachine.Transition{
		{From: Idle, Event: Start, To: ReceivingManifest},
		{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
	}

	transitions = append(transitions, []statemachine.Transition{
		{From: ReceivingAlgorithm, Event: RunFailed, To: Failed},
		{From: ReceivingData, Event: RunFailed, To: Failed},
		{From: Running, Event: RunComplete, To: ConsumingResults},
		{From: Running, Event: RunFailed, To: Failed},
		{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
	}...)

	for _, t := range transitions {
		sm.AddTransition(t)
	}

	sm.SetAction(ReceivingAlgorithm, svc.downloadAlgorithmIfRemote)
	sm.SetAction(ReceivingData, svc.downloadDatasetsIfRemote)
	sm.SetAction(Running, svc.runComputation)
	sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
	sm.SetAction(Complete, svc.publishEvent(Completed.String()))
	sm.SetAction(Failed, svc.publishEvent(Failed.String()))

	go func() {
		if err := sm.Start(ctx); err != nil {
			logger.Error(err.Error())
		}
	}()

	time.Sleep(100 * time.Millisecond)
	sm.SendEvent(Start)

	time.Sleep(100 * time.Millisecond)

	return svc
}

func (as *agentService) State() string {
	return as.sm.GetState().String()
}

func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error {
	if as.sm.GetState() != ReceivingManifest {
		return ErrStateNotReady
	}
	defer as.sm.SendEvent(ManifestReceived)

	as.mu.Lock()
	defer as.mu.Unlock()

	as.computation = cmp

	// Debug: Log manifest details
	as.logger.Info("received computation manifest",
		"computation_id", cmp.ID,
		"kbs_enabled", cmp.KBS.Enabled,
		"kbs_url", cmp.KBS.URL,
		"algo_has_source", cmp.Algorithm.Source != nil,
		"dataset_count", len(cmp.Datasets))

	if cmp.Algorithm.Source != nil {
		as.logger.Info("algorithm remote source configured",
			"url", cmp.Algorithm.Source.URL,
			"kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath)
	} else {
		as.logger.Info("algorithm remote source NOT configured - will wait for direct upload")
	}

	if cmp.KBS.Enabled {
		as.logger.Info("KBS is ENABLED", "url", cmp.KBS.URL)
	} else {
		as.logger.Info("KBS is NOT ENABLED")
	}

	for i, d := range cmp.Datasets {
		if d.Source != nil {
			as.logger.Info("dataset remote source configured",
				"index", i,
				"filename", d.Filename,
				"url", d.Source.URL,
				"kbs_resource_path", d.Source.KBSResourcePath)
		}
	}

	transitions := []statemachine.Transition{}

	if len(cmp.Datasets) == 0 {
		transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
	} else {
		transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
		transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
	}

	for _, t := range transitions {
		as.sm.AddTransition(t)
	}

	return nil
}

func (as *agentService) StopComputation(ctx context.Context) error {
	as.mu.Lock()
	defer as.mu.Unlock()

	as.eventSvc.SendEvent(as.computation.ID, "Stopped", "Stopped", json.RawMessage{})

	as.cancel()

	if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil {
		as.logger.Warn("failed to stop runner", "error", err)
		// proceed to cleanup
	}

	if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
		return fmt.Errorf("error removing datasets directory: %w", err)
	}

	if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
		return fmt.Errorf("error removing results directory: %w", err)
	}

	as.sm.Reset(Idle)

	as.computation = Computation{}
	as.algoReceived = false
	as.algoType = ""
	as.algoArgs = nil
	as.algoRequirements = nil
	as.result = nil
	as.runError = nil
	as.resultsConsumed = false

	ctx, cancel := context.WithCancel(ctx)
	as.cancel = cancel

	go func() {
		if err := as.sm.Start(ctx); err != nil {
			as.logger.Error(err.Error())
		}
	}()

	time.Sleep(100 * time.Millisecond)
	as.sm.SendEvent(Start)

	time.Sleep(100 * time.Millisecond)

	return nil
}

// downloadAlgorithmIfRemote automatically downloads the algorithm if it has a remote source.
// This is called as an action when entering the ReceivingAlgorithm state.
func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) {
	as.publishEvent(InProgress.String())(state)

	as.mu.Lock()
	defer as.mu.Unlock()

	// Debug: log the decision point.
	as.logger.Info("checking if algorithm should be downloaded automatically",
		"algo_has_source", as.computation.Algorithm.Source != nil,
		"kbs_enabled", as.computation.KBS.Enabled)

	// Check if the algorithm should be downloaded from a remote source.
	if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
		as.logger.Info("downloading algorithm from remote source",
			"url", as.computation.Algorithm.Source.URL,
			"kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath)

		// Use a background context for the download operation.
		ctx := context.Background()

		res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
		if err != nil {
			as.runError = fmt.Errorf("failed to download and decrypt algorithm: %w", err)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		// Verify the hash against the value declared in the computation manifest.
		hash := sha3.Sum256(res.Data)
		if hash != as.computation.Algorithm.Hash {
			as.runError = fmt.Errorf("algorithm hash mismatch: expected %x, got %x", as.computation.Algorithm.Hash, hash)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		// Write the algorithm to a file in the current working directory.
		currentDir, err := os.Getwd()
		if err != nil {
			as.runError = fmt.Errorf("error getting current directory: %w", err)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		// If a source directory is available (e.g. from OCI extraction), copy all files.
		if res.SourceDir != "" {
			as.logger.Info("copying extracted algorithm directory", "src", res.SourceDir, "dst", currentDir)
			// Simple recursive copy (using shell cp for simplicity and reliability on Linux).
			// The trailing "/." copies the contents of SourceDir into currentDir.
			cmd := exec.Command("cp", "-r", res.SourceDir+"/.", currentDir)
			if out, err := cmd.CombinedOutput(); err != nil {
				as.runError = fmt.Errorf("error copying algorithm directory: %w, output: %s", err, out)
				as.logger.Error(as.runError.Error())
				as.sm.SendEvent(RunFailed)
				return
			}
		}

		f, err := os.Create(filepath.Join(currentDir, "algo"))
		if err != nil {
			as.runError = fmt.Errorf("error creating algorithm file: %w", err)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		if _, err := f.Write(res.Data); err != nil {
			as.runError = fmt.Errorf("error writing algorithm to file: %w", err)
			as.logger.Error(as.runError.Error())
			f.Close()
			as.sm.SendEvent(RunFailed)
			return
		}

		if err := os.Chmod(f.Name(), algoFilePermission); err != nil {
			as.runError = fmt.Errorf("error changing file permissions: %w", err)
			as.logger.Error(as.runError.Error())
			f.Close()
			as.sm.SendEvent(RunFailed)
			return
		}

		if err := f.Close(); err != nil {
			as.runError = fmt.Errorf("error closing file: %w", err)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		as.algoReceived = true
		as.algoRequirements = res.Requirements // Store requirements for installation

		// Create the datasets directory.
		if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
			as.runError = fmt.Errorf("error creating datasets directory: %w", err)
			as.logger.Error(as.runError.Error())
			as.sm.SendEvent(RunFailed)
			return
		}

		// Default to a binary algorithm when the manifest does not set a type.
		as.algoType = as.computation.Algorithm.AlgoType
		if as.algoType == "" {
			as.algoType = string(algorithm.AlgoTypeBin)
		}
		as.algoArgs = as.computation.Algorithm.AlgoArgs

		as.logger.Info("algorithm downloaded and saved successfully", "type", as.algoType, "has_requirements", len(res.Requirements) > 0)
		as.sm.SendEvent(AlgorithmReceived)
	} else {
		// If no remote source, do nothing and wait for direct upload via the Algo() RPC call.
		as.logger.Info("algorithm automatic download not triggered, waiting for direct upload",
			"reason", "no remote source or KBS not enabled")
	}
}

// downloadDatasetsIfRemote automatically downloads datasets that have remote sources.
// This is called as an action when entering the ReceivingData state.
func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
	as.publishEvent(InProgress.String())(state)

	as.mu.Lock()
	defer as.mu.Unlock()

	// Check if any datasets should be downloaded from remote sources.
	hasRemoteDatasets := false
	for _, d := range as.computation.Datasets {
		if d.Source != nil && as.computation.KBS.Enabled {
			hasRemoteDatasets = true
			break
		}
	}

	if !hasRemoteDatasets {
		// No remote datasets, wait for direct uploads via Data() RPC calls.
		return
	}

	// Download all remote datasets. Iterating backwards keeps the remaining
	// indices valid when an entry is deleted from the slice.
	ctx := context.Background()
	for i := len(as.computation.Datasets) - 1; i >= 0; i-- {
		d := as.computation.Datasets[i]
		if d.Source != nil && as.computation.KBS.Enabled {
			as.logger.Info("downloading dataset from remote source", "filename", d.Filename)

			res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
			if err != nil {
				as.runError = fmt.Errorf("failed to download and decrypt dataset %q: %w", d.Filename, err)
				as.logger.Error("failed to download and decrypt dataset", "error", err, "filename", d.Filename)
				as.sm.SendEvent(RunFailed)
				return
			}

			// Verify the hash against the value declared in the computation manifest.
			hash := sha3.Sum256(res.Data)
			if hash != d.Hash {
				as.runError = fmt.Errorf("dataset %q hash mismatch: expected %x, got %x", d.Filename, d.Hash, hash)
				as.logger.Error("dataset hash mismatch", "filename", d.Filename)
				as.sm.SendEvent(RunFailed)
				return
			}

			if d.Decompress {
				// Compressed datasets are extracted directly into the datasets directory.
				if err := internal.UnzipFromMemory(res.Data, algorithm.DatasetsDir); err != nil {
					as.runError = fmt.Errorf("error decompressing dataset %q: %w", d.Filename, err)
					as.logger.Error("error decompressing dataset", "error", err, "filename", d.Filename)
					as.sm.SendEvent(RunFailed)
					return
				}
			} else {
				// Uncompressed datasets are written to a file named after d.Filename.
				f, err := os.Create(filepath.Join(algorithm.DatasetsDir, d.Filename))
				if err != nil {
					as.runError = fmt.Errorf("error creating dataset file %q: %w", d.Filename, err)
					as.logger.Error("error creating dataset file", "error", err, "filename", d.Filename)
					as.sm.SendEvent(RunFailed)
					return
				}

				if _, err := f.Write(res.Data); err != nil {
					as.runError = fmt.Errorf("error writing dataset %q to file: %w", d.Filename, err)
					as.logger.Error("error writing dataset to file", "error", err, "filename", d.Filename)
					f.Close()
					as.sm.SendEvent(RunFailed)
					return
				}

				if err := f.Close(); err != nil {
					as.runError = fmt.Errorf("error closing dataset file %q: %w", d.Filename, err)
					as.logger.Error("error closing file", "error", err, "filename", d.Filename)
					as.sm.SendEvent(RunFailed)
					return
				}
			}

			// Remove from pending datasets.
			as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
			as.logger.Info("dataset downloaded and saved successfully", "filename", d.Filename)
		}
	}

	// If all datasets are downloaded, send the DataReceived event.
	if len(as.computation.Datasets) == 0 {
		as.logger.Info("all datasets downloaded successfully")
		as.sm.SendEvent(DataReceived)
	}
	// Otherwise, wait for remaining datasets to be uploaded via Data() RPC calls.
}

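// DecryptedResource holds the data and metadata of a downloaded and decrypted resource.
type DecryptedResource struct {
	// Data is the content of the extracted resource file.
	Data []byte
	// Requirements is the content of requirements.txt when one is found next
	// to an algorithm; it is nil otherwise.
	Requirements []byte
	// SourceDir is the directory that contains the extracted resource file.
	SourceDir string
}
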
// downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider.
// For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically.
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
	// Determine source type
	sourceType := source.Type
	if sourceType == "" {
		// Infer from URL
		if strings.HasPrefix(source.URL, "docker://") || strings.HasPrefix(source.URL, "oci:") {
			sourceType = "oci-image"
		} else {
			return nil, fmt.Errorf("unsupported source URL format: %s (use oci-image type)", source.URL)
		}
	}

	switch sourceType {
	case "oci-image":
		return as.downloadAndDecryptOCIImage(ctx, source, resourceType)
	default:
		return nil, fmt.Errorf("unsupported source type: %s", sourceType)
	}
}

// downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider.
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
	as.logger.Info("downloading OCI image",
		"url", source.URL,
		"encrypted", source.Encrypted,
		"kbs_path", source.KBSResourcePath)

	// The OCI client must have been initialized before any download.
	if as.ociClient == nil {
		return nil, fmt.Errorf("OCI client not initialized")
	}

	// Create OCI resource source
	ociSource := oci.ResourceSource{
		Type:            oci.ResourceTypeOCIImage,
		URI:             source.URL,
		Encrypted:       source.Encrypted,
		KBSResourcePath: source.KBSResourcePath,
	}

	// Pull and decrypt image.
	// CoCo Keyprovider will automatically handle decryption via ocicrypt.
	// Sanitize directory name to avoid Skopeo interpreting ':' as tag separator.
	sanitizedName := strings.ReplaceAll(filepath.Base(source.URL), ":", "_")
	destDir := filepath.Join(os.TempDir(), "cocos-oci", "images", sanitizedName)
	if err := as.ociClient.PullAndDecrypt(ctx, ociSource, destDir); err != nil {
		return nil, fmt.Errorf("failed to pull and decrypt OCI image: %w", err)
	}

	as.logger.Info("OCI image downloaded and decrypted", "dest", destDir)

	// Extract the resource file from the OCI layers.
	extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName)
	var resourcePath string

	if resourceType == "algorithm" {
		path, err := oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir)
		if err != nil {
			return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
		}
		resourcePath = path
		as.logger.Info("algorithm extracted from OCI image", "path", resourcePath)
	} else {
		// Assume dataset
		files, err := oci.ExtractDataset(destDir, extractDir)
		if err != nil {
			return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err)
		}
		// An empty result is reported separately so that a nil error is never wrapped.
		if len(files) == 0 {
			return nil, fmt.Errorf("failed to extract dataset from OCI image: no files found in %s", source.URL)
		}
		// For now, take the first file found.
		// nolint:godox // TODO: Handle multiple files / directory structure if needed.
		resourcePath = files[0]
		as.logger.Info("dataset extracted from OCI image", "path", resourcePath)
	}

	// Read the extracted resource file.
	resourceData, err := os.ReadFile(resourcePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read %s file: %w", resourceType, err)
	}

	// Check for requirements.txt if algorithm
	var reqData []byte
	if resourceType == "algorithm" {
		reqPath := filepath.Join(filepath.Dir(resourcePath), "requirements.txt")
		if data, err := os.ReadFile(reqPath); err == nil {
			reqData = data
			as.logger.Info("found requirements.txt", "size", len(data))
		}
	}

	as.logger.Info("resource loaded", "type", resourceType, "size", len(resourceData))

	return &DecryptedResource{
		Data:         resourceData,
		Requirements: reqData,
		SourceDir:    filepath.Dir(resourcePath),
	}, nil
}

func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
	if as.sm.GetState() != ReceivingAlgorithm {
		return ErrStateNotReady
	}
	as.mu.Lock()
	defer as.mu.Unlock()
	if as.algoReceived {
		return ErrAllManifestItemsReceived
	}

	var algoData []byte
	// Requirements come from the upload unless the algorithm is fetched remotely.
	requirements := algo.Requirements

	// Check if algorithm should be downloaded from remote source
	if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
		as.logger.Info("downloading algorithm from remote source")

		res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
		if err != nil {
			return fmt.Errorf("failed to download and decrypt algorithm: %w", err)
		}

		algoData = res.Data
		requirements = res.Requirements
	} else {
		// Use directly uploaded algorithm
		algoData = algo.Algorithm
	}

	hash := sha3.Sum256(algoData)

	if hash != as.computation.Algorithm.Hash {
		return ErrHashMismatch
	}

	currentDir, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error getting current directory: %w", err)
	}

	f, err := os.Create(filepath.Join(currentDir, "algo"))
	if err != nil {
		return fmt.Errorf("error creating algorithm file: %w", err)
	}

	if _, err := f.Write(algoData); err != nil {
		f.Close()
		return fmt.Errorf("error writing algorithm to file: %w", err)
	}

	if err := os.Chmod(f.Name(), algoFilePermission); err != nil {
		f.Close()
		return fmt.Errorf("error changing file permissions: %w", err)
	}

	if err := f.Close(); err != nil {
		return fmt.Errorf("error closing file: %w", err)
	}

	// Default to a binary algorithm when the context does not carry a type.
	algoType := algorithm.AlgorithmTypeFromContext(ctx)
	if algoType == "" {
		algoType = string(algorithm.AlgoTypeBin)
	}

	args := algorithm.AlgorithmArgsFromContext(ctx)

	as.algoType = algoType
	as.algoArgs = args
	as.algoRequirements = requirements
	as.algoReceived = true

	if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
		return fmt.Errorf("error creating datasets directory: %w", err)
	}

	as.sm.SendEvent(AlgorithmReceived)

	return nil
}

func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
	if as.sm.GetState() != ReceivingData {
		return ErrStateNotReady
	}
	as.mu.Lock()
	defer as.mu.Unlock()
	if len(as.computation.Datasets) == 0 {
		return ErrAllManifestItemsReceived
	}

	var datasetData []byte
	var datasetFilename string

	// Check if any dataset should be downloaded from remote source
	matchedIndex := -1
	for i, d := range as.computation.Datasets {
		if d.Source != nil && as.computation.KBS.Enabled {
			as.logger.Info("downloading dataset from remote source", "filename", d.Filename)

			downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
			if err != nil {
				return fmt.Errorf("failed to download and decrypt dataset: %w", err)
			}

			datasetData = downloadedData.Data
			datasetFilename = d.Filename
			matchedIndex = i
			break
		}
	}

	// If no remote dataset, use uploaded dataset
	if matchedIndex == -1 {
		datasetData = dataset.Dataset
		datasetFilename = dataset.Filename
	}

	hash := sha3.Sum256(datasetData)

	// Match the dataset against the pending manifest entries by hash.
	matched := false
	for i, d := range as.computation.Datasets {
		if hash == d.Hash {
			if d.Filename != "" && d.Filename != datasetFilename {
				return ErrFileNameMismatch
			}

			if DecompressFromContext(ctx) {
				if err := internal.UnzipFromMemory(datasetData, algorithm.DatasetsDir); err != nil {
					return fmt.Errorf("error decompressing dataset: %w", err)
				}
			} else {
				f, err := os.Create(filepath.Join(algorithm.DatasetsDir, datasetFilename))
				if err != nil {
					return fmt.Errorf("error creating dataset file: %w", err)
				}

				if _, err := f.Write(datasetData); err != nil {
					f.Close()
					return fmt.Errorf("error writing dataset to file: %w", err)
				}
				if err := f.Close(); err != nil {
					return fmt.Errorf("error closing file: %w", err)
				}
			}

			// Remove the entry from the pending list only once the dataset is stored.
			as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)

			matched = true
			break
		}
	}

	if !matched {
		return ErrUndeclaredDataset
	}

	if len(as.computation.Datasets) == 0 {
		defer as.sm.SendEvent(DataReceived)
	}

	return nil
}

func (as *agentService) Result(ctx context.Context) ([]byte, error) {
	currentState := as.sm.GetState()
	if currentState != ConsumingResults && currentState != Complete && currentState != Failed {
		return []byte{}, ErrResultsNotReady
	}

	index, ok := IndexFromContext(ctx)
	if !ok {
		return []byte{}, ErrUndeclaredConsumer
	}

	as.mu.Lock()
	defer as.mu.Unlock()
	if index < 0 || index >= len(as.computation.ResultConsumers) {
		return []byte{}, ErrUndeclaredConsumer
	}

	if !as.resultsConsumed && currentState == ConsumingResults {
		as.resultsConsumed = true
		defer as.sm.SendEvent(ResultsConsumed)
	}

	return as.result, as.runError
}

func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
	rawQuote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType)
	if err != nil {
		return []byte{}, errors.Wrap(ErrAttestationFailed, err)
	}
	return rawQuote, nil
}

func (as *agentService) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) {
	token, err := as.attestationClient.GetAzureToken(ctx, nonce)
	if err != nil {
		return []byte{}, err
	}
	return token, nil
}

func (as *agentService) runComputation(state statemachine.State) {
	as.publishEvent(Starting.String())(state)
	as.logger.Debug("computation run started")
	// Report the outcome to the state machine once the run finishes.
	defer func() {
		if as.runError != nil {
			as.sm.SendEvent(RunFailed)
		} else {
			as.sm.SendEvent(RunComplete)
		}
	}()

	if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
		as.runError = fmt.Errorf("error creating results directory: %w", err)
		as.logger.Warn(as.runError.Error())
		as.publishEvent(Failed.String())(state)
		return
	}

	// Remove the results and datasets directories when the run ends.
	defer func() {
		if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
		}
		if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
		}
	}()

	// Read algo file
	currentDir, err := os.Getwd()
	if err != nil {
		as.runError = fmt.Errorf("error getting current directory: %w", err)
		as.logger.Warn(as.runError.Error())
		as.publishEvent(Failed.String())(state)
		return
	}
	algoFile := filepath.Join(currentDir, "algo")
	algoBytes, err := os.ReadFile(algoFile)
	if err != nil {
		as.runError = fmt.Errorf("failed to read algo file: %w", err)
		as.logger.Warn(as.runError.Error())
		as.publishEvent(Failed.String())(state)
		return
	}

	as.publishEvent(InProgress.String())(state)

	// Call Runner
	resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{
		ComputationId: as.computation.ID,
		AlgoType:      as.algoType,
		Algorithm:     algoBytes,
		Requirements:  as.algoRequirements,
		Args:          as.algoArgs,
		// Datasets implicit on shared FS
	})
	if err != nil {
		as.runError = err
		as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
		as.publishEvent(Failed.String())(state)
		return
	}

	if resp.Error != "" {
		as.runError = errors.New(resp.Error)
		as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error))
		as.publishEvent(Failed.String())(state)
		return
	}

	results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
	if err != nil {
		as.runError = err
		as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
		as.publishEvent(Failed.String())(state)
		return
	}

	as.publishEvent(Completed.String())(state)

	as.result = results
}

func (as *agentService) publishEvent(status string) statemachine.Action {
	return func(state statemachine.State) {
		as.eventSvc.SendEvent(as.computation.ID, state.String(), status, json.RawMessage{})
	}
}

func (as *agentService) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) {
	data, err := os.ReadFile(ImaMeasurementsFilePath)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading Linux IMA measurements file: %w", err)
	}

	pcr10, err := vtpm.GetPCRSHA1Value(ImaPcrIndex)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading TPM PCR #10: %w", err)
	}

	return data, pcr10, nil
}