mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Agent Pull mode for remote resources (#575)
CI / checkproto (push) Has been cancelled
CI / lint (push) Has been cancelled
Rust CI Pipeline / rust-check (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
CI / checkproto (push) Has been cancelled
CI / lint (push) Has been cancelled
Rust CI Pipeline / rust-check (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat(kbs): implement KBS client for attestation and resource retrieval - Added KBS client implementation in pkg/kbs/client.go with methods for attestation and resource retrieval. - Introduced necessary data structures for requests and responses. - Implemented error handling for various scenarios. test(kbs): add unit tests for KBS client - Created comprehensive tests for the KBS client in pkg/kbs/client_test.go. - Included tests for attestation success and failure cases, as well as resource retrieval. feat(registry): introduce HTTP and S3 registry implementations - Added HTTPRegistry for downloading resources over HTTP/HTTPS with retry logic in pkg/registry/http.go. - Implemented S3Registry for downloading resources from AWS S3 and S3-compatible services in pkg/registry/s3.go. - Included error handling and configuration options for both registries. chore(registry): define registry interface and configuration - Created registry interface and configuration struct in pkg/registry/registry.go. - Added default configuration settings for registry clients. docs(cvms): update README for CVMS server configuration and usage - Enhanced documentation for CVMS server with detailed command-line flags and usage examples. - Clarified direct upload and remote resource modes, including KBS integration. fix(cvms): integrate KBS for remote resource handling in main.go - Updated main.go to support remote datasets and algorithms using KBS. - Added validation for command-line flags to ensure proper configuration. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Move ifeq conditional outside define block in attestation-service.mk Make conditionals cannot be evaluated inside define...endef blocks when used as recipe bodies. Restructured to define the ATTESTATION_SERVICE_INSTALL_INIT_SYSTEMD block conditionally based on BR2_PACKAGE_CC_ATTESTATION_AGENT configuration. * feat: Implement remote resource downloading for algorithms and datasets using AWS S3/MinIO credentials. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add comprehensive documentation and agent support for testing remote resource download with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Improve agent logging for remote resource configuration and KBS status, and add a testing guide for remote resource downloads with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add a comprehensive guide for testing remote resource download with KBS attestation and update multiple package versions to a specific commit. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add failure transitions for resource reception states and a comprehensive guide for testing remote resource downloads with KBS attestation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Implement remote resource download with KBS attestation in the agent and add a comprehensive testing guide. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: Add comprehensive guide for testing remote resource download with KBS attestation and include a debug log in the attestation client. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Delegate KBS attestation and token retrieval to a new attestation-agent service and document remote resource testing. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * client fixes Signed-off-by: Sammy Oina <sammyoina@gmail.com> * raw evidence Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Build all Go files in cmd directories, not just main.go This fixes the issue where fetch_raw_evidence.go wasn't being included in the attestation-service build. * fix: Wrap binary evidence in JSON for KBS compatibility Fixes 'invalid character' error by wrapping raw binary evidence in a JSON structure with base64 encoding, as expected by KBS. * chore: Update buildroot packages toc28cefaeIncludes fixes for: 1. attestation-service build (including fetch_raw_evidence.go) 2. Agent KBS evidence format (wrapping binary in JSON) * fix: Implement KBS RCAR handshake with cookies Fixes 'cookie not found' error (401) from KBS by: 1. Adding CookieJar support to KBS client 2. Implementing GetChallenge() to perform /auth handshake and capture session cookie 3. Updating Agent to get challenge, decode nonce, and use it for evidence generation 4. Regenerating mocks * chore: Update buildroot packages tof6981ac5Includes KBS RCAR handshake fix (cookie support + GetChallenge loop) * fix: Update KBS client JSON tags to kebab-case Fixes deserialization error (401) from KBS by: 1. Using kebab-case (e.g. extra-params) for JSON tags as per protocol. 2. Initializing ExtraParams as empty object {} instead of null/omitted. * fix: Wrap attestation evidence in primary_evidence format Updates Agent to construct 'tee-evidence' payload with: - primary_evidence: containing the actual quote/data - additional_evidence: empty JSON object This matches the Confidential Containers KBS Attestation Protocol requirements. * fix: Update KBS protocol version to 0.4.0 KBS rejected 0.1.0 with a version mismatch error. Bumping to 0.4.0 to match server expectation. * fix: Generate ephemeral key for KBS RuntimeData Updates RuntimeData to include a valid ephemeral EC P-256 public key in JWK format, as required by the KBS RCAR protocol. Also fixes the KBS client struct to support TEEPubKey as an object. * fix: Update sample attestation quote to valid JSON The default attestation.bin was binary, but the KBS Sample Verifier expects a valid JSON quote containing 'svn' and 'report_data'. Updated the embedded bin file to contain this JSON structure. * fix: Generate dynamic JSON quote for Sample TEE in FetchRawEvidence The KBS Sample Verifier expects a JSON object with 'svn' and 'report_data'. Previously, we were returning raw binary data (reportData+nonce). This commit updates FetchRawEvidence to return a marshaled JSON structure with: - svn: "1" - report_data: base64(req.ReportData) * refactor: Delegate Sample Attestation to Provider Refactored sample attestation logic: - Moved JSON Quote generation into EmptyProvider (standalone mode). - Updated FetchRawEvidence to call provider.TeeAttestation instead of manual generation. This enables using the real CC Attestation Agent for UNSPECIFIED platform if configured. * feat: Add comprehensive debug logging and enforce CC AA usage Changes: - Updated EmptyProvider to return error instead of generating mock data This forces proper use of CC Attestation Agent's sample attester - Added detailed logging to attestation-service FetchRawEvidence: * Hex dump of evidence (first 200 bytes) * String preview of evidence * Total evidence length - Added detailed logging to agent service: * Raw evidence hex and string previews * KBS evidence JSON preview (first 500 bytes) * Evidence lengths at each transformation step This logging will help diagnose why KBS Sample Verifier is rejecting evidence. * fix: Enable CC AA by default and add attestation-service log forwarding Changes: - Set USE_CC_ATTESTATION_AGENT=true by default in systemd service - Added StandardOutput/StandardError to forward logs to /var/log/cocos/ - Updated HAL makefile to handle new default value - This ensures attestation-service uses CC AA's sample attester - Logs will now be visible in CVMS output for debugging * feat: Add gRPC log forwarding to attestation-service Implemented the same log forwarding mechanism used by the agent: - Added ProtoHandler to write logs to both stdout and logQueue - Connected to log client (/run/cocos/log.sock) for gRPC forwarding - Added goroutine to forward logs to CVMS via log client - Logs will now appear in CVMS output during computation runs This enables visibility into attestation-service debug output including: - CC AA connection status - Evidence generation details (hex dumps, string previews) - Any errors from providers * fix: Parse sample evidence JSON instead of base64-encoding it The attestation-service returns sample evidence as JSON: {"svn":"1","report_data":"base64..."} The agent was incorrectly base64-encoding this JSON string again. KBS Sample Verifier expects the parsed JSON object directly. Fixed by: - Parsing the JSON evidence from attestation-service - Passing the parsed object directly in primary_evidence.evidence - This matches what KBS Sample Verifier expects * debug: Increase KBS evidence logging preview to 1000 bytes Show the complete JSON structure being sent to KBS to debug the attestation failure. * debug: Add comprehensive CC AA configuration logging Added debug logs to show: - Whether CC AA is enabled in config - CC AA address being used - Connection success/failure - Which provider is ultimately selected - Warning when falling back to EmptyProvider This will help diagnose why EmptyProvider is being used instead of CC Attestation Agent. * debug: Add startup logging for log client connection Added log message to show if log client connection succeeds at attestation-service startup. This will help diagnose why logs aren't appearing in CVMS output. * feat: Add retry logic with exponential backoff to log client Added simple retry mechanism to handle concurrent log requests: - 3 retry attempts with exponential backoff (10ms, 20ms, 40ms) - Applies to both SendLog and SendEvent methods - Centralized in log client so all services benefit - Should eliminate 'failed to send log' errors from concurrent requests This fixes the issue where attestation-service logs weren't appearing in CVMS output due to dropped messages. * fix: Flatten sample evidence fields in primary_evidence for KBS KBS Sample Verifier expects svn and report_data at the top level of primary_evidence, not nested under an 'evidence' key. Changed structure from: {"primary_evidence": {"tee": "sample", "evidence": {"svn": "1", ...}}} To: {"primary_evidence": {"tee": "sample", "svn": "1", "report_data": "...", ...}} This matches what KBS expects when deserializing the Quote structure. * fix: Use sample quote directly as primary_evidence per KBS protocol According to KBS attestation protocol spec, for sample TEE type, primary_evidence should be the sample quote JSON directly: {"svn": "1", "report_data": "..."} Removed extra 'tee' and 'platform' fields that were causing KBS to fail deserializing the Quote structure. The 'tee' field is already sent in the Request payload during RCAR handshake. Refs: - https://github.com/confidential-containers/trustee/blob/main/kbs/docs/kbs_attestation_protocol.md - https://github.com/confidential-containers/guest-components/blob/main/attestation-agent/attester/src/sample/mod.rs * fix: Make CC AA required for sample attestation when configured When USE_CC_ATTESTATION_AGENT=true, attestation-service now requires AA to be available for NoCC/sample platform. This ensures sample evidence always comes from AA with the correct KBS format. Changes: - Error out if AA connection fails for NoCC platform when AA is configured - Only use EmptyProvider if AA is explicitly NOT configured - Prevents incorrect sample evidence format from EmptyProvider This ensures attestation-service delegates to AA for sample evidence generation instead of creating it itself. * fix: Implement proper RCAR protocol with tee-pubkey and runtime-data hash Fixed KBS attestation error 'REPORT_DATA is different from that in Sample Quote' Changes: 1. Generate ephemeral EC key pair BEFORE getting evidence from AA 2. Create runtime-data with nonce + tee-pubkey (JWK format) 3. Hash runtime-data (SHA-256) and use as report_data for AA 4. This binds the tee-pubkey to the TEE evidence per RCAR protocol The report_data in the evidence now matches what KBS expects: hash(runtime-data) instead of computation ID. This completes the full RCAR protocol implementation: - Request → Challenge → Attestation (with bound tee-pubkey) → Response * fix(agent): use simple nonce for Sample attestation report_data For Sample/NoCC attestation, use the raw nonce bytes directly as report_data instead of hashing runtime-data. This avoids JSON serialization mismatches with the KBS Sample verifier. Real TEEs (TDX/SNP) still use runtime-data hash binding to cryptographically bind the ephemeral tee-pubkey to the evidence. * fix(agent): use RFC 8785 canonical JSON for runtime-data hashing The KBS Sample attestation verifier (and likely others) expects the report_data to be the SHA-256 hash of the *canonical* JSON serialization (RFC 8785) of the runtime-data. Standard Go JSON marshaling does not guarantee key ordering, leading to hash mismatches. This change uses github.com/gowebpki/jcs to canonicalize the runtime-data before hashing, ensuring compatibility with the KBS RCAR implementation. Also reverted the temporary 'simple nonce' workaround. * feat(hal): add CoCo Keyprovider and Skopeo packages - Add coco-keyprovider buildroot package with systemd service - Add skopeo buildroot package for OCI image handling - Add ocicrypt_keyprovider.conf for encrypted image decryption - Update Config.in to include new packages This enables standard CoCo ecosystem integration for encrypted OCI images instead of custom S3/HTTP registry clients. * feat(oci): add OCI image handling package with Skopeo integration - Add pkg/oci/types.go with ResourceSource and ImageManifest types - Add pkg/oci/skopeo.go with Skopeo wrapper for pull/decrypt - Add pkg/oci/extract.go for extracting algorithms and datasets from layers This package provides OCI image handling using Skopeo and CoCo Keyprovider for encrypted image decryption, replacing custom S3/HTTP registry clients. * chore: regenerate protobuf files for updated cvms.proto * refactor(agent): replace S3/HTTP/KBS with OCI package - Remove pkg/kbs and pkg/registry imports - Add pkg/oci import for OCI image handling - Replace downloadAndDecryptResource with OCI-based implementation - Use Skopeo + CoCo Keyprovider for automatic decryption - Reduce code from ~240 lines to ~70 lines This eliminates custom KBS RCAR handshake, S3/HTTP registry clients, and manual decryption logic. CoCo Keyprovider handles all decryption automatically via ocicrypt protocol. * chore: remove obsolete pkg/kbs and pkg/registry packages - Delete pkg/kbs/ (custom KBS client, ~300 lines) - Delete pkg/registry/ (S3/HTTP registry clients, ~400 lines) - Remove unused imports from agent/service.go - Run go mod tidy to clean up dependencies These packages have been replaced by pkg/oci with Skopeo and CoCo Keyprovider for standard CoCo ecosystem integration. * fix(agent): update ResourceSource struct to include type and encryption fields Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix(hal): update CoCo Keyprovider to v0.16.0 and fix build path - Update version from v0.11.0 to v0.16.0 (matches attestation agent) - Fix install path: target is at repo root, not in coco_keyprovider subdir - This fixes the build error where coco_keyprovider binary wasn't found The cargo workspace in guest-components builds to a shared target/ directory at the repository root, not within each crate's subdirectory. * feat: Update remote resources testing guide to use kbs-client and coco-keyprovider for key management and encryption, enable insecure TLS for Skopeo, and enhance CVMS with Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update component versions, revise image encryption documentation, and sanitize OCI image paths for Skopeo compatibility. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add `decompress` option to Dataset and `algo_type`/`algo_args` to Algorithm protobuf messages, updating client, test, and build configurations. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update multiple package versions and enhance OCI image extraction error reporting for missing algorithm files. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Bump package versions, improve OCI image extraction debugging by returning seen files, and remove unused dataset type parsing from test code. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Migrate OCI extraction to use structured logging with `slog` and `context`, and update package versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Bump multiple component versions, add encrypted status for computation inputs and algorithms, and refine OCI layer extraction warnings. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * logging Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add `Encrypted` field to algorithm and dataset resource sources and update all component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: update component versions, integrate coco-keyprovider service, and configure ocicrypt key provider. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add support for KBS parameters and dataset/algorithm hash calculations in CVMS Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: update resource download and extraction logic to support requirements.txt and improve hash verification Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Update dependencies, improve code style, and add GetRawEvidence to attestation client mocks. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor code structure for improved readability and maintainability Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: update golangci configuration to include errcheck for build path and remove unnecessary exclusions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: streamline kernel command line handling in QEMU args construction Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add attestation binary and update checksum tests and policy structure Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for attestation agent, attestation, log, crypto, OCI, and Skopeo clients - Implement tests for the attestation agent client including Unix socket and TCP address handling, token retrieval, and error scenarios. - Enhance attestation client tests to cover fetching raw evidence for various platforms (SNP, TDX, VTPM, SNPvTPM) and validate error handling. - Introduce log client tests to verify retry behavior for sending logs and events. - Create comprehensive tests for crypto package focusing on AES-GCM decryption, encrypted resource parsing, and key unwrapping. - Add tests for OCI package to validate algorithm and dataset extraction, including JSON serialization of OCILayout. - Implement Skopeo client tests to ensure proper functionality for image pulling, inspecting, and resource source handling. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: handle JSON marshal errors in test cases for decrypt and extract functions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add comprehensive tests for algorithm and dataset extraction with various scenarios Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: replace hardcoded Python script content with constant variable Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: remove redundant mock expectation for SendAgentConfig in TestCreateVMWithAaKbsParams Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add tests for event sending failure, dataset extraction with path traversal, and Skopeo client behavior Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add tests for download and decryption of resources with various URL formats Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Introduce OCIClient interface for agent service to improve testability of OCI image operations and enhance related tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Change `get_uint64_from_tcb` to accept `TcbVersion` by value and use `u64::from` for type conversions. --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f77ec5644a
commit
da31d76c94
@@ -24,6 +24,21 @@ The service is configured using the environment variables from the following tab
|
||||
| AGENT_OS_BUILD | Operating system build information for attestation | UVC |
|
||||
| AGENT_OS_DISTRO | Operating system distribution information for attestation | UVC |
|
||||
| AGENT_OS_TYPE | Operating system type information for attestation | UVC |
|
||||
| ATTESTATION_SERVICE_SOCKET | Unix socket path for attestation service communication | /run/cocos/attestation.sock |
|
||||
| AGENT_ENABLE_ATLS | Enable Attestation TLS for secure communication | true |
|
||||
|
||||
### Remote Resource Download (Optional)
|
||||
|
||||
The agent supports downloading encrypted algorithms and datasets from remote registries (S3, HTTP/HTTPS) and retrieving decryption keys from a Key Broker Service (KBS) via attestation.
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- | ----------------------------------------------- |
|
||||
| AWS_REGION | AWS region for S3 access (required for S3 downloads) | \"\" |
|
||||
| AWS_ACCESS_KEY_ID | AWS access key ID for S3 authentication | \"\" |
|
||||
| AWS_SECRET_ACCESS_KEY | AWS secret access key for S3 authentication | \"\" |
|
||||
| AWS_ENDPOINT_URL | Custom S3 endpoint URL (for S3-compatible services like MinIO) | \"\" |
|
||||
|
||||
**Note**: KBS URL is specified in the computation manifest, not as an environment variable. See [TESTING_REMOTE_RESOURCES.md](./TESTING_REMOTE_RESOURCES.md) for details on using remote resources.
|
||||
|
||||
## Deployment
|
||||
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
# Testing Remote Resources with CoCo Key Provider
|
||||
|
||||
This guide explains how to test Cocos with encrypted remote resources using the Confidential Containers Key Provider ecosystem.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ CVM (Agent) │
|
||||
│ │
|
||||
│ ┌──────────┐ ┌────────────────┐ ┌─────────────────┐ │
|
||||
│ │ Agent │───▶│ Skopeo │───▶│ CoCo Keyprovider│ │
|
||||
│ └──────────┘ │ (ocicrypt) │ │ (gRPC:50011) │ │
|
||||
│ └────────────────┘ └────────┬────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ Attestation │ │
|
||||
│ │ Agent (50002) │ │
|
||||
│ └────────┬────────┘ │
|
||||
└──────────────────────────────────────────────────┼──────────┘
|
||||
│
|
||||
┌────────▼────────┐
|
||||
│ KBS Server │
|
||||
│ (Host:8080) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Install Skopeo (Host Machine)
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install skopeo
|
||||
|
||||
# macOS
|
||||
brew install skopeo
|
||||
|
||||
# Or build from source
|
||||
git clone https://github.com/containers/skopeo
|
||||
cd skopeo
|
||||
make bin/skopeo
|
||||
sudo make install
|
||||
```
|
||||
|
||||
### 2. Start KBS Server (Host Machine)
|
||||
|
||||
```bash
|
||||
# Clone and build KBS
|
||||
git clone https://github.com/confidential-containers/trustee
|
||||
cd trustee/kbs
|
||||
# Patch Cargo.toml to disable SGX requirement (for testing only)
|
||||
sed -i 's/"all-verifier",//g' Cargo.toml
|
||||
|
||||
make
|
||||
make cli
|
||||
|
||||
# Generate admin keys
|
||||
openssl genpkey -algorithm ed25519 -out kbs-admin.key
|
||||
openssl pkey -in kbs-admin.key -pubout -out kbs-admin.pub
|
||||
|
||||
# Create KBS configuration file
|
||||
cat > kbs-config.toml << 'EOF'
|
||||
[http_server]
|
||||
sockets = ["0.0.0.0:8080"]
|
||||
insecure_http = true
|
||||
|
||||
[admin]
|
||||
type = "Simple"
|
||||
[[admin.personas]]
|
||||
id = "admin"
|
||||
public_key_path = "kbs-admin.pub"
|
||||
|
||||
[attestation_service]
|
||||
type = "coco_as_builtin"
|
||||
work_dir = "kbs-data/as"
|
||||
|
||||
[attestation_service.rvps_config]
|
||||
type = "BuiltIn"
|
||||
|
||||
[attestation_service.rvps_config.storage]
|
||||
type = "LocalFs"
|
||||
file_path = "kbs-data/rvps-values"
|
||||
|
||||
[[plugins]]
|
||||
name = "resource"
|
||||
type = "LocalFs"
|
||||
dir_path = "kbs-data/repository"
|
||||
EOF
|
||||
|
||||
# Create configuration directories
|
||||
mkdir -p kbs-data/as kbs-data/rvps kbs-data/repository
|
||||
|
||||
# Start KBS
|
||||
../target/release/kbs --config-file kbs-config.toml
|
||||
```
|
||||
|
||||
KBS will listen on `http://localhost:8080`
|
||||
|
||||
### 3. Setup Local OCI Registry (Optional)
|
||||
|
||||
For testing, you can use a local registry:
|
||||
|
||||
```bash
|
||||
docker run -d -p 5000:5000 --name registry registry:2
|
||||
```
|
||||
|
||||
## Creating Encrypted Resources
|
||||
|
||||
### Encrypt an Algorithm (Python Script)
|
||||
|
||||
```bash
|
||||
# 1. Create a simple algorithm
|
||||
cat > lin_reg.py << 'EOF'
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import sys
|
||||
|
||||
# Load dataset
|
||||
data = pd.read_csv(sys.argv[1])
|
||||
X = data[['feature1', 'feature2']]
|
||||
y = data['target']
|
||||
|
||||
# Train model
|
||||
model = LinearRegression()
|
||||
model.fit(X, y)
|
||||
|
||||
# Save results
|
||||
print(f"Coefficients: {model.coef_}")
|
||||
print(f"Intercept: {model.intercept_}")
|
||||
EOF
|
||||
|
||||
# 2. Create a Dockerfile
|
||||
cat > Dockerfile << 'EOF'
|
||||
FROM python:3.9-slim
|
||||
RUN pip install pandas scikit-learn
|
||||
COPY lin_reg.py /app/algorithm.py
|
||||
WORKDIR /app
|
||||
ENTRYPOINT ["python", "algorithm.py"]
|
||||
EOF
|
||||
|
||||
# 3. Build the image
|
||||
docker build -t localhost:5000/lin-reg-algo:v1.0 .
|
||||
docker push localhost:5000/lin-reg-algo:v1.0
|
||||
|
||||
# 4. Generate and store key
|
||||
openssl rand -out algo.key 32
|
||||
|
||||
# 5. Store key in KBS using kbs-client
|
||||
../target/release/kbs-client --url http://localhost:8080 config \
|
||||
--auth-private-key kbs-admin.key \
|
||||
set-resource \
|
||||
--path default/key/algo-key \
|
||||
--resource-file algo.key
|
||||
|
||||
# 6. Encrypt the image using Host Skopeo + Docker Keyprovider
|
||||
# Start Keyprovider in background
|
||||
docker run -d --rm --name keyprovider --network host \
|
||||
-v "$PWD:/work" -w /work \
|
||||
ghcr.io/confidential-containers/staged-images/coco-keyprovider:latest \
|
||||
coco_keyprovider --socket 127.0.0.1:50000
|
||||
|
||||
# Configure Ocicrypt to use local Keyprovider
|
||||
cat <<EOF > ocicrypt.conf
|
||||
{
|
||||
"key-providers": {
|
||||
"attestation-agent": {
|
||||
"grpc": "127.0.0.1:50000"
|
||||
}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=$(pwd)/ocicrypt.conf
|
||||
|
||||
# Encrypt Algo
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--encryption-key "provider:attestation-agent:keypath=/work/algo.key::keyid=kbs:///default/key/algo-key::algorithm=A256GCM" \
|
||||
docker://localhost:5000/lin-reg-algo:v1.0 \
|
||||
docker://localhost:5000/encrypted-lin-reg:v1.0
|
||||
|
||||
# Stop Keyprovider
|
||||
docker stop keyprovider
|
||||
```
|
||||
|
||||
### Encrypt a Dataset (CSV in OCI Image)
|
||||
|
||||
```bash
|
||||
# 1. Create dataset
|
||||
cat > iris.csv << 'EOF'
|
||||
feature1,feature2,target
|
||||
5.1,3.5,0
|
||||
4.9,3.0,0
|
||||
6.2,3.4,1
|
||||
5.9,3.0,1
|
||||
EOF
|
||||
|
||||
# 2. Create Dockerfile for dataset
|
||||
cat > Dockerfile.dataset << 'EOF'
|
||||
FROM scratch
|
||||
COPY iris.csv /data/iris.csv
|
||||
EOF
|
||||
|
||||
# 3. Build and push
|
||||
docker build -f Dockerfile.dataset -t localhost:5000/iris-dataset:v1.0 .
|
||||
docker push localhost:5000/iris-dataset:v1.0
|
||||
|
||||
# 4. Generate and store key
|
||||
# 4. Generate and store key
|
||||
openssl rand -out dataset.key 32
|
||||
../target/release/kbs-client --url http://localhost:8080 config \
|
||||
--auth-private-key kbs-admin.key \
|
||||
set-resource \
|
||||
--path default/key/dataset-key \
|
||||
--resource-file dataset.key
|
||||
|
||||
# 5. Encrypt dataset image using Host Skopeo + Docker Keyprovider
|
||||
# Start Keyprovider in background
|
||||
docker run -d --rm --name keyprovider --network host \
|
||||
-v "$PWD:/work" -w /work \
|
||||
ghcr.io/confidential-containers/staged-images/coco-keyprovider:latest \
|
||||
coco_keyprovider --socket 127.0.0.1:50000
|
||||
|
||||
# Configure Ocicrypt (if not already done)
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=$(pwd)/ocicrypt.conf
|
||||
|
||||
# Encrypt Dataset
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--encryption-key "provider:attestation-agent:keypath=/work/dataset.key::keyid=kbs:///default/key/dataset-key::algorithm=A256GCM" \
|
||||
docker://localhost:5000/iris-dataset:v1.0 \
|
||||
docker://localhost:5000/encrypted-iris:v1.0
|
||||
|
||||
# Stop Keyprovider
|
||||
docker stop keyprovider
|
||||
```
|
||||
|
||||
## Running a Computation
|
||||
|
||||
### 1. Start Manager (Host)
|
||||
|
||||
```bash
|
||||
cd /path/to/cocos-ai
|
||||
./build/cocos-manager
|
||||
```
|
||||
|
||||
### 2. Start CVMS Test Server (Host)
|
||||
|
||||
Get your host IP:
|
||||
```bash
|
||||
HOST_IP=$(ip -4 addr show | grep -oP '(?<=inet\s)\d+(\.\d+){3}' | grep -v 127.0.0.1 | head -n1)
|
||||
```
|
||||
|
||||
Start CVMS server:
|
||||
```bash
|
||||
# Calculate SHA3-256 of decrypted files using cocos-cli
|
||||
# NOTE: We use the hash of the original plaintext files, as the Agent validates the decrypted content.
|
||||
# Redirect stderr to stdout (2>&1) because cocos-cli prints to stderr
|
||||
ALGO_HASH=$(./build/cocos-cli checksum lin_reg.py 2>&1 | awk '{print $NF}')
|
||||
DATASET_HASH=$(./build/cocos-cli checksum iris.csv 2>&1 | awk '{print $NF}')
|
||||
|
||||
go build -o build/cvms-test ./test/cvms/main.go
|
||||
HOST=$HOST_IP PORT=7001 ./build/cvms-test \
|
||||
-public-key-path ./public.pem \
|
||||
-attested-tls-bool false \
|
||||
-kbs-url http://$HOST_IP:8080 \
|
||||
-algo-type oci-image \
|
||||
-algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \
|
||||
-algo-kbs-path default/key/algo-key \
|
||||
-algo-hash $ALGO_HASH \
|
||||
-dataset-type oci-image \
|
||||
-dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \
|
||||
-dataset-kbs-paths default/key/dataset-key \
|
||||
-dataset-hash $DATASET_HASH
|
||||
```
|
||||
|
||||
### 3. Create VM via CLI (Host)
|
||||
|
||||
```bash
|
||||
export MANAGER_GRPC_URL=localhost:7002
|
||||
./build/cocos-cli create-vm \
|
||||
--server-url $HOST_IP:7001 \
|
||||
--log-level debug
|
||||
```
|
||||
|
||||
The agent will:
|
||||
1. Receive computation manifest from CVMS
|
||||
2. Use Skopeo to download encrypted OCI images
|
||||
3. Skopeo invokes CoCo Keyprovider via ocicrypt
|
||||
4. CoCo Keyprovider requests decryption key from KBS
|
||||
5. Attestation Agent generates TEE evidence for KBS
|
||||
6. KBS validates evidence and returns decryption key
|
||||
7. Image layers are decrypted and extracted
|
||||
8. Computation executes with decrypted algorithm and dataset
|
||||
|
||||
## Verifying the Setup
|
||||
|
||||
### Check CoCo Keyprovider Status (Inside CVM)
|
||||
|
||||
```bash
|
||||
# SSH into CVM or use console
|
||||
systemctl status coco-keyprovider
|
||||
journalctl -u coco-keyprovider -f
|
||||
```
|
||||
|
||||
### Check Attestation Agent Status
|
||||
|
||||
```bash
|
||||
systemctl status attestation-agent
|
||||
journalctl -u attestation-agent -f
|
||||
```
|
||||
|
||||
### Test Skopeo Decryption Manually
|
||||
|
||||
```bash
|
||||
# Inside CVM
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=/etc/ocicrypt_keyprovider.conf
|
||||
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--decryption-key provider:attestation-agent:cc_kbc::null \
|
||||
docker://localhost:5000/encrypted-lin-reg:v1.0 \
|
||||
oci:/tmp/decrypted-algo
|
||||
|
||||
# Verify decryption
|
||||
skopeo inspect oci:/tmp/decrypted-algo | jq -r '.LayersData[].MIMEType'
|
||||
# Should show: application/vnd.oci.image.layer.v1.tar+gzip
|
||||
```
|
||||
|
||||
## Computation Manifest Format
|
||||
|
||||
The CVMS server sends this manifest to the agent:
|
||||
|
||||
```json
|
||||
{
|
||||
"computation_id": "1",
|
||||
"algorithm": {
|
||||
"type": "oci-image",
|
||||
"uri": "docker://localhost:5000/encrypted-lin-reg:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/algo-key"
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"type": "oci-image",
|
||||
"uri": "docker://localhost:5000/encrypted-iris:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/dataset-key"
|
||||
}
|
||||
],
|
||||
"kbs_url": "http://192.168.100.15:8080"
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CoCo Keyprovider Not Starting
|
||||
|
||||
```bash
|
||||
# Check logs
|
||||
journalctl -u coco-keyprovider -n 50
|
||||
|
||||
# Verify socket is listening
|
||||
ss -tlnp | grep 50011
|
||||
|
||||
# Check environment
|
||||
cat /etc/default/coco-keyprovider
|
||||
```
|
||||
|
||||
### Skopeo Decryption Fails
|
||||
|
||||
```bash
|
||||
# Verify ocicrypt config
|
||||
cat /etc/ocicrypt_keyprovider.conf
|
||||
|
||||
# Test keyprovider connection
|
||||
grpcurl -plaintext 127.0.0.1:50011 list
|
||||
|
||||
# Check KBS connectivity from CVM
|
||||
curl http://HOST_IP:8080/kbs/v0/auth
|
||||
```
|
||||
|
||||
### KBS Returns 401
|
||||
|
||||
```bash
|
||||
# Check KBS logs on host
|
||||
# Verify attestation evidence format
|
||||
# Ensure KBS is configured for sample attestation
|
||||
```
|
||||
|
||||
## Differences from Previous Approach
|
||||
|
||||
| Aspect | Old (Custom) | New (CoCo Standard) |
|
||||
|--------|-------------|---------------------|
|
||||
| **Download** | Custom S3/HTTP clients | Skopeo (OCI standard) |
|
||||
| **Decryption** | Custom KBS client | CoCo Keyprovider |
|
||||
| **Attestation** | Direct KBS RCAR | AA → CoCo KP → KBS |
|
||||
| **Format** | Raw encrypted files | OCI encrypted images |
|
||||
| **Complexity** | ~2000 lines custom code | Standard CoCo components |
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Standards Compliance**: Uses OCI and CoCo standards
|
||||
2. **Better Tooling**: Leverage Skopeo, Docker, Podman ecosystem
|
||||
3. **Simplified Code**: Remove custom registry/decryption logic
|
||||
4. **Proven Solution**: Battle-tested CoCo components
|
||||
5. **Docker Native**: Works with existing Docker workflows
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Encrypt your algorithms and datasets as OCI images
|
||||
- Push to your preferred OCI registry (Docker Hub, GHCR, etc.)
|
||||
- Update computation manifests to use `oci-image` type
|
||||
- Test end-to-end flow with encrypted workloads
|
||||
+34
-8
@@ -20,6 +20,26 @@ type AgentConfig struct {
|
||||
AttestedTls bool `json:"attested_tls,omitempty"`
|
||||
}
|
||||
|
||||
// ResourceSource specifies the location of a remote encrypted resource.
|
||||
type ResourceSource struct {
|
||||
// Type is the type of resource source (currently only "oci-image" is supported)
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is the location of the resource (e.g., docker://registry/repo:tag)
|
||||
URL string `json:"url,omitempty"`
|
||||
// KBSResourcePath is the path to the decryption key in KBS (e.g., "default/key/my-key")
|
||||
KBSResourcePath string `json:"kbs_resource_path,omitempty"`
|
||||
// Encrypted indicates whether the resource is encrypted and requires KBS
|
||||
Encrypted bool `json:"encrypted,omitempty"`
|
||||
}
|
||||
|
||||
// KBSConfig holds configuration for Key Broker Service.
|
||||
type KBSConfig struct {
|
||||
// URL is the KBS endpoint (e.g., "https://kbs.example.com")
|
||||
URL string `json:"url,omitempty"`
|
||||
// Enabled indicates whether to use KBS for key retrieval
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
type Computation struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -27,6 +47,7 @@ type Computation struct {
|
||||
Datasets Datasets `json:"datasets,omitempty"`
|
||||
Algorithm Algorithm `json:"algorithm,omitempty"`
|
||||
ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"`
|
||||
KBS KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type ResultConsumer struct {
|
||||
@@ -42,19 +63,24 @@ func (d *Datasets) String() string {
|
||||
}
|
||||
|
||||
type Dataset struct {
|
||||
Dataset []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Dataset []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
Decompress bool `json:"decompress,omitempty"`
|
||||
}
|
||||
|
||||
type Datasets []Dataset
|
||||
|
||||
type Algorithm struct {
|
||||
Algorithm []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Requirements []byte `json:"-"`
|
||||
Algorithm []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Requirements []byte `json:"-"`
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
AlgoType string `json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `json:"algo_args,omitempty"`
|
||||
}
|
||||
|
||||
type ManifestIndexKey struct{}
|
||||
|
||||
@@ -238,13 +238,34 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
}
|
||||
// Copy remote source if configured
|
||||
if runReq.Algorithm.Source != nil {
|
||||
ac.Algorithm.Source = &agent.ResourceSource{
|
||||
URL: runReq.Algorithm.Source.Url,
|
||||
KBSResourcePath: runReq.Algorithm.Source.KbsResourcePath,
|
||||
Encrypted: runReq.Algorithm.Source.Encrypted,
|
||||
}
|
||||
}
|
||||
ac.Algorithm.AlgoType = runReq.Algorithm.AlgoType
|
||||
ac.Algorithm.AlgoArgs = runReq.Algorithm.AlgoArgs
|
||||
}
|
||||
|
||||
for _, ds := range runReq.Datasets {
|
||||
ac.Datasets = append(ac.Datasets, agent.Dataset{
|
||||
Hash: [32]byte(ds.Hash),
|
||||
UserKey: ds.UserKey,
|
||||
})
|
||||
dataset := agent.Dataset{
|
||||
Hash: [32]byte(ds.Hash),
|
||||
UserKey: ds.UserKey,
|
||||
Filename: ds.Filename,
|
||||
}
|
||||
// Copy remote source if configured
|
||||
if ds.Source != nil {
|
||||
dataset.Source = &agent.ResourceSource{
|
||||
URL: ds.Source.Url,
|
||||
KBSResourcePath: ds.Source.KbsResourcePath,
|
||||
Encrypted: ds.Source.Encrypted,
|
||||
}
|
||||
}
|
||||
dataset.Decompress = ds.Decompress
|
||||
ac.Datasets = append(ac.Datasets, dataset)
|
||||
}
|
||||
|
||||
for _, rc := range runReq.ResultConsumers {
|
||||
@@ -253,6 +274,14 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
})
|
||||
}
|
||||
|
||||
// Copy KBS configuration
|
||||
if runReq.Kbs != nil {
|
||||
ac.KBS = agent.KBSConfig{
|
||||
URL: runReq.Kbs.Url,
|
||||
Enabled: runReq.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the agent is in the correct state to initialize a new computation.
|
||||
// If the agent is already processing this computation (e.g., after a reconnection),
|
||||
// skip initialization to avoid state errors.
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
|
||||
@@ -513,3 +514,140 @@ func TestManagerClient_sendMessageTimeout(t *testing.T) {
|
||||
// Should complete without blocking
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleRunReqChunksWithRemoteSource tests handling run request with remote source.
|
||||
func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id-remote",
|
||||
Name: "test-computation",
|
||||
Description: "test description",
|
||||
Datasets: []*cvms.Dataset{
|
||||
{
|
||||
Hash: sha3.New256().Sum([]byte("test-dataset")),
|
||||
Filename: "data.csv",
|
||||
Source: &cvms.Source{
|
||||
Type: "oci-image",
|
||||
Url: "docker://registry.example.com/data:v1",
|
||||
KbsResourcePath: "default/key/data-key",
|
||||
Encrypted: true,
|
||||
},
|
||||
Decompress: true,
|
||||
},
|
||||
},
|
||||
Algorithm: &cvms.Algorithm{
|
||||
Hash: sha3.New256().Sum([]byte("test-algorithm")),
|
||||
AlgoType: "python",
|
||||
AlgoArgs: []string{"--verbose"},
|
||||
Source: &cvms.Source{
|
||||
Type: "oci-image",
|
||||
Url: "docker://registry.example.com/algo:v1",
|
||||
KbsResourcePath: "default/key/algo-key",
|
||||
Encrypted: true,
|
||||
},
|
||||
},
|
||||
Kbs: &cvms.KBSConfig{
|
||||
Url: "https://kbs.example.com:8080",
|
||||
Enabled: true,
|
||||
},
|
||||
ResultConsumers: []*cvms.ResultConsumer{
|
||||
{
|
||||
UserKey: []byte("test-consumer"),
|
||||
},
|
||||
},
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-remote-1",
|
||||
Data: runReqBytes,
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("State").Return("ReceivingManifest")
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.MatchedBy(func(c agent.Computation) bool {
|
||||
// Verify KBS config is passed
|
||||
if !c.KBS.Enabled || c.KBS.URL != "https://kbs.example.com:8080" {
|
||||
return false
|
||||
}
|
||||
// Verify algorithm source is passed
|
||||
if c.Algorithm.Source == nil ||
|
||||
c.Algorithm.Source.URL != "docker://registry.example.com/algo:v1" ||
|
||||
c.Algorithm.Source.KBSResourcePath != "default/key/algo-key" ||
|
||||
!c.Algorithm.Source.Encrypted {
|
||||
return false
|
||||
}
|
||||
// Verify algorithm type and args
|
||||
if c.Algorithm.AlgoType != "python" || len(c.Algorithm.AlgoArgs) != 1 || c.Algorithm.AlgoArgs[0] != "--verbose" {
|
||||
return false
|
||||
}
|
||||
// Verify dataset source is passed
|
||||
if len(c.Datasets) != 1 ||
|
||||
c.Datasets[0].Source == nil ||
|
||||
c.Datasets[0].Source.URL != "docker://registry.example.com/data:v1" ||
|
||||
c.Datasets[0].Filename != "data.csv" ||
|
||||
!c.Datasets[0].Decompress {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})).Return(nil)
|
||||
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleRunReqChunksAlreadyProcessing tests skipping init when already processing.
|
||||
func TestManagerClient_handleRunReqChunksAlreadyProcessing(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id-processing",
|
||||
Name: "test-computation",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-processing-1",
|
||||
Data: runReqBytes,
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate agent already processing a computation
|
||||
mockSvc.On("State").Return("Running")
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// InitComputation should NOT be called since state is not ReceivingManifest
|
||||
mockSvc.AssertNotCalled(t, "InitComputation")
|
||||
}
|
||||
|
||||
+221
-32
@@ -826,6 +826,7 @@ type ComputationRunReq struct {
|
||||
Algorithm *Algorithm `protobuf:"bytes,5,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
|
||||
ResultConsumers []*ResultConsumer `protobuf:"bytes,6,rep,name=result_consumers,json=resultConsumers,proto3" json:"result_consumers,omitempty"`
|
||||
AgentConfig *AgentConfig `protobuf:"bytes,7,opt,name=agent_config,json=agentConfig,proto3" json:"agent_config,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,8,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration for remote resources
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -909,6 +910,13 @@ func (x *ComputationRunReq) GetAgentConfig() *AgentConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ComputationRunReq) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ResultConsumer struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
UserKey []byte `protobuf:"bytes,1,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
@@ -958,6 +966,8 @@ type Dataset struct {
|
||||
Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length.
|
||||
UserKey []byte `protobuf:"bytes,2,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
Filename string `protobuf:"bytes,3,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
Source *Source `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted dataset
|
||||
Decompress bool `protobuf:"varint,5,opt,name=decompress,proto3" json:"decompress,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1013,10 +1023,27 @@ func (x *Dataset) GetFilename() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Dataset) GetSource() *Source {
|
||||
if x != nil {
|
||||
return x.Source
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Dataset) GetDecompress() bool {
|
||||
if x != nil {
|
||||
return x.Decompress
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Algorithm struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length.
|
||||
UserKey []byte `protobuf:"bytes,2,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
Source *Source `protobuf:"bytes,3,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted algorithm
|
||||
AlgoType string `protobuf:"bytes,4,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `protobuf:"bytes,5,rep,name=algo_args,json=algoArgs,proto3" json:"algo_args,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1065,6 +1092,147 @@ func (x *Algorithm) GetUserKey() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetSource() *Source {
|
||||
if x != nil {
|
||||
return x.Source
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetAlgoType() string {
|
||||
if x != nil {
|
||||
return x.AlgoType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetAlgoArgs() []string {
|
||||
if x != nil {
|
||||
return x.AlgoArgs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` // Type of source: "oci-image" (only OCI images supported for CoCo)
|
||||
Url string `protobuf:"bytes,2,opt,name=url,proto3" json:"url,omitempty"` // URL of the OCI image (e.g., docker://registry/repo:tag)
|
||||
KbsResourcePath string `protobuf:"bytes,3,opt,name=kbs_resource_path,json=kbsResourcePath,proto3" json:"kbs_resource_path,omitempty"` // Path to decryption key in KBS (e.g., "default/key/my-key")
|
||||
Encrypted bool `protobuf:"varint,4,opt,name=encrypted,proto3" json:"encrypted,omitempty"` // Whether the resource is encrypted (requires KBS)
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Source) Reset() {
|
||||
*x = Source{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Source) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Source) ProtoMessage() {}
|
||||
|
||||
func (x *Source) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Source.ProtoReflect.Descriptor instead.
|
||||
func (*Source) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{15}
|
||||
}
|
||||
|
||||
func (x *Source) GetType() string {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetUrl() string {
|
||||
if x != nil {
|
||||
return x.Url
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetKbsResourcePath() string {
|
||||
if x != nil {
|
||||
return x.KbsResourcePath
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetEncrypted() bool {
|
||||
if x != nil {
|
||||
return x.Encrypted
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type KBSConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` // KBS endpoint URL (e.g., "https://kbs.example.com")
|
||||
Enabled bool `protobuf:"varint,2,opt,name=enabled,proto3" json:"enabled,omitempty"` // Whether to use KBS for key retrieval
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *KBSConfig) Reset() {
|
||||
*x = KBSConfig{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *KBSConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*KBSConfig) ProtoMessage() {}
|
||||
|
||||
func (x *KBSConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use KBSConfig.ProtoReflect.Descriptor instead.
|
||||
func (*KBSConfig) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{16}
|
||||
}
|
||||
|
||||
func (x *KBSConfig) GetUrl() string {
|
||||
if x != nil {
|
||||
return x.Url
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *KBSConfig) GetEnabled() bool {
|
||||
if x != nil {
|
||||
return x.Enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Port string `protobuf:"bytes,1,opt,name=port,proto3" json:"port,omitempty"`
|
||||
@@ -1080,7 +1248,7 @@ type AgentConfig struct {
|
||||
|
||||
func (x *AgentConfig) Reset() {
|
||||
*x = AgentConfig{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1092,7 +1260,7 @@ func (x *AgentConfig) String() string {
|
||||
func (*AgentConfig) ProtoMessage() {}
|
||||
|
||||
func (x *AgentConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1105,7 +1273,7 @@ func (x *AgentConfig) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use AgentConfig.ProtoReflect.Descriptor instead.
|
||||
func (*AgentConfig) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{15}
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{17}
|
||||
}
|
||||
|
||||
func (x *AgentConfig) GetPort() string {
|
||||
@@ -1167,7 +1335,7 @@ type AttestationResponse struct {
|
||||
|
||||
func (x *AttestationResponse) Reset() {
|
||||
*x = AttestationResponse{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[18]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1179,7 +1347,7 @@ func (x *AttestationResponse) String() string {
|
||||
func (*AttestationResponse) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[18]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1192,7 +1360,7 @@ func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use AttestationResponse.ProtoReflect.Descriptor instead.
|
||||
func (*AttestationResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{16}
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{18}
|
||||
}
|
||||
|
||||
func (x *AttestationResponse) GetFile() []byte {
|
||||
@@ -1219,7 +1387,7 @@ type AzureAttestationToken struct {
|
||||
|
||||
func (x *AzureAttestationToken) Reset() {
|
||||
*x = AzureAttestationToken{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[19]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1231,7 +1399,7 @@ func (x *AzureAttestationToken) String() string {
|
||||
func (*AzureAttestationToken) ProtoMessage() {}
|
||||
|
||||
func (x *AzureAttestationToken) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[19]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1244,7 +1412,7 @@ func (x *AzureAttestationToken) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use AzureAttestationToken.ProtoReflect.Descriptor instead.
|
||||
func (*AzureAttestationToken) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{17}
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{19}
|
||||
}
|
||||
|
||||
func (x *AzureAttestationToken) GetFile() []byte {
|
||||
@@ -1317,7 +1485,7 @@ const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\fRunReqChunks\x12\x12\n" +
|
||||
"\x04data\x18\x01 \x01(\fR\x04data\x12\x0e\n" +
|
||||
"\x02id\x18\x02 \x01(\tR\x02id\x12\x17\n" +
|
||||
"\ais_last\x18\x03 \x01(\bR\x06isLast\"\xaa\x02\n" +
|
||||
"\ais_last\x18\x03 \x01(\bR\x06isLast\"\xcd\x02\n" +
|
||||
"\x11ComputationRunReq\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
|
||||
"\x04name\x18\x02 \x01(\tR\x04name\x12 \n" +
|
||||
@@ -1325,16 +1493,32 @@ const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\bdatasets\x18\x04 \x03(\v2\r.cvms.DatasetR\bdatasets\x12-\n" +
|
||||
"\talgorithm\x18\x05 \x01(\v2\x0f.cvms.AlgorithmR\talgorithm\x12?\n" +
|
||||
"\x10result_consumers\x18\x06 \x03(\v2\x14.cvms.ResultConsumerR\x0fresultConsumers\x124\n" +
|
||||
"\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\"*\n" +
|
||||
"\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\x12!\n" +
|
||||
"\x03kbs\x18\b \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"*\n" +
|
||||
"\x0eResultConsumer\x12\x18\n" +
|
||||
"\auserKey\x18\x01 \x01(\fR\auserKey\"S\n" +
|
||||
"\auserKey\x18\x01 \x01(\fR\auserKey\"\x99\x01\n" +
|
||||
"\aDataset\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12\x1a\n" +
|
||||
"\bfilename\x18\x03 \x01(\tR\bfilename\"9\n" +
|
||||
"\bfilename\x18\x03 \x01(\tR\bfilename\x12$\n" +
|
||||
"\x06source\x18\x04 \x01(\v2\f.cvms.SourceR\x06source\x12\x1e\n" +
|
||||
"\n" +
|
||||
"decompress\x18\x05 \x01(\bR\n" +
|
||||
"decompress\"\x99\x01\n" +
|
||||
"\tAlgorithm\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\"\xe5\x01\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12$\n" +
|
||||
"\x06source\x18\x03 \x01(\v2\f.cvms.SourceR\x06source\x12\x1b\n" +
|
||||
"\talgo_type\x18\x04 \x01(\tR\balgoType\x12\x1b\n" +
|
||||
"\talgo_args\x18\x05 \x03(\tR\balgoArgs\"x\n" +
|
||||
"\x06Source\x12\x12\n" +
|
||||
"\x04type\x18\x01 \x01(\tR\x04type\x12\x10\n" +
|
||||
"\x03url\x18\x02 \x01(\tR\x03url\x12*\n" +
|
||||
"\x11kbs_resource_path\x18\x03 \x01(\tR\x0fkbsResourcePath\x12\x1c\n" +
|
||||
"\tencrypted\x18\x04 \x01(\bR\tencrypted\"7\n" +
|
||||
"\tKBSConfig\x12\x10\n" +
|
||||
"\x03url\x18\x01 \x01(\tR\x03url\x12\x18\n" +
|
||||
"\aenabled\x18\x02 \x01(\bR\aenabled\"\xe5\x01\n" +
|
||||
"\vAgentConfig\x12\x12\n" +
|
||||
"\x04port\x18\x01 \x01(\tR\x04port\x12\x1b\n" +
|
||||
"\tcert_file\x18\x02 \x01(\tR\bcertFile\x12\x19\n" +
|
||||
@@ -1364,7 +1548,7 @@ func file_agent_cvms_cvms_proto_rawDescGZIP() []byte {
|
||||
return file_agent_cvms_cvms_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_cvms_cvms_proto_msgTypes = make([]protoimpl.MessageInfo, 18)
|
||||
var file_agent_cvms_cvms_proto_msgTypes = make([]protoimpl.MessageInfo, 20)
|
||||
var file_agent_cvms_cvms_proto_goTypes = []any{
|
||||
(*AgentStateReq)(nil), // 0: cvms.AgentStateReq
|
||||
(*AgentStateRes)(nil), // 1: cvms.AgentStateRes
|
||||
@@ -1381,21 +1565,23 @@ var file_agent_cvms_cvms_proto_goTypes = []any{
|
||||
(*ResultConsumer)(nil), // 12: cvms.ResultConsumer
|
||||
(*Dataset)(nil), // 13: cvms.Dataset
|
||||
(*Algorithm)(nil), // 14: cvms.Algorithm
|
||||
(*AgentConfig)(nil), // 15: cvms.AgentConfig
|
||||
(*AttestationResponse)(nil), // 16: cvms.AttestationResponse
|
||||
(*AzureAttestationToken)(nil), // 17: cvms.azureAttestationToken
|
||||
(*timestamppb.Timestamp)(nil), // 18: google.protobuf.Timestamp
|
||||
(*Source)(nil), // 15: cvms.Source
|
||||
(*KBSConfig)(nil), // 16: cvms.KBSConfig
|
||||
(*AgentConfig)(nil), // 17: cvms.AgentConfig
|
||||
(*AttestationResponse)(nil), // 18: cvms.AttestationResponse
|
||||
(*AzureAttestationToken)(nil), // 19: cvms.azureAttestationToken
|
||||
(*timestamppb.Timestamp)(nil), // 20: google.protobuf.Timestamp
|
||||
}
|
||||
var file_agent_cvms_cvms_proto_depIdxs = []int32{
|
||||
18, // 0: cvms.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
18, // 1: cvms.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
|
||||
20, // 0: cvms.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
20, // 1: cvms.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
|
||||
6, // 2: cvms.ClientStreamMessage.agent_log:type_name -> cvms.AgentLog
|
||||
5, // 3: cvms.ClientStreamMessage.agent_event:type_name -> cvms.AgentEvent
|
||||
4, // 4: cvms.ClientStreamMessage.run_res:type_name -> cvms.RunResponse
|
||||
3, // 5: cvms.ClientStreamMessage.stopComputationRes:type_name -> cvms.StopComputationResponse
|
||||
1, // 6: cvms.ClientStreamMessage.agentStateRes:type_name -> cvms.AgentStateRes
|
||||
16, // 7: cvms.ClientStreamMessage.vTPMattestationReport:type_name -> cvms.AttestationResponse
|
||||
17, // 8: cvms.ClientStreamMessage.azureAttestationToken:type_name -> cvms.azureAttestationToken
|
||||
18, // 7: cvms.ClientStreamMessage.vTPMattestationReport:type_name -> cvms.AttestationResponse
|
||||
19, // 8: cvms.ClientStreamMessage.azureAttestationToken:type_name -> cvms.azureAttestationToken
|
||||
10, // 9: cvms.ServerStreamMessage.runReqChunks:type_name -> cvms.RunReqChunks
|
||||
11, // 10: cvms.ServerStreamMessage.runReq:type_name -> cvms.ComputationRunReq
|
||||
2, // 11: cvms.ServerStreamMessage.stopComputation:type_name -> cvms.StopComputation
|
||||
@@ -1404,14 +1590,17 @@ var file_agent_cvms_cvms_proto_depIdxs = []int32{
|
||||
13, // 14: cvms.ComputationRunReq.datasets:type_name -> cvms.Dataset
|
||||
14, // 15: cvms.ComputationRunReq.algorithm:type_name -> cvms.Algorithm
|
||||
12, // 16: cvms.ComputationRunReq.result_consumers:type_name -> cvms.ResultConsumer
|
||||
15, // 17: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig
|
||||
7, // 18: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 19: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
19, // [19:20] is the sub-list for method output_type
|
||||
18, // [18:19] is the sub-list for method input_type
|
||||
18, // [18:18] is the sub-list for extension type_name
|
||||
18, // [18:18] is the sub-list for extension extendee
|
||||
0, // [0:18] is the sub-list for field type_name
|
||||
17, // 17: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig
|
||||
16, // 18: cvms.ComputationRunReq.kbs:type_name -> cvms.KBSConfig
|
||||
15, // 19: cvms.Dataset.source:type_name -> cvms.Source
|
||||
15, // 20: cvms.Algorithm.source:type_name -> cvms.Source
|
||||
7, // 21: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 22: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
22, // [22:23] is the sub-list for method output_type
|
||||
21, // [21:22] is the sub-list for method input_type
|
||||
21, // [21:21] is the sub-list for extension type_name
|
||||
21, // [21:21] is the sub-list for extension extendee
|
||||
0, // [0:21] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_cvms_cvms_proto_init() }
|
||||
@@ -1441,7 +1630,7 @@ func file_agent_cvms_cvms_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_cvms_cvms_proto_rawDesc), len(file_agent_cvms_cvms_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 18,
|
||||
NumMessages: 20,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -92,6 +92,7 @@ message ComputationRunReq {
|
||||
Algorithm algorithm = 5;
|
||||
repeated ResultConsumer result_consumers = 6;
|
||||
AgentConfig agent_config = 7;
|
||||
KBSConfig kbs = 8; // Optional KBS configuration for remote resources
|
||||
}
|
||||
|
||||
message ResultConsumer {
|
||||
@@ -102,11 +103,28 @@ message Dataset {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
string filename = 3;
|
||||
Source source = 4; // Optional remote source for encrypted dataset
|
||||
bool decompress = 5;
|
||||
}
|
||||
|
||||
message Algorithm {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
Source source = 3; // Optional remote source for encrypted algorithm
|
||||
string algo_type = 4;
|
||||
repeated string algo_args = 5;
|
||||
}
|
||||
|
||||
message Source {
|
||||
string type = 1; // Type of source: "oci-image" (only OCI images supported for CoCo)
|
||||
string url = 2; // URL of the OCI image (e.g., docker://registry/repo:tag)
|
||||
string kbs_resource_path = 3; // Path to decryption key in KBS (e.g., "default/key/my-key")
|
||||
bool encrypted = 4; // Whether the resource is encrypted (requires KBS)
|
||||
}
|
||||
|
||||
message KBSConfig {
|
||||
string url = 1; // KBS endpoint URL (e.g., "https://kbs.example.com")
|
||||
bool enabled = 2; // Whether to use KBS for key retrieval
|
||||
}
|
||||
|
||||
message AgentConfig {
|
||||
|
||||
@@ -18,6 +18,11 @@ func (m *MockAttestationClient) GetAttestation(ctx context.Context, reportData [
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
args := m.Called(ctx, reportData, nonce, attType)
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
||||
args := m.Called(ctx, nonce)
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
|
||||
+418
-10
@@ -9,8 +9,10 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
sync "sync"
|
||||
"time"
|
||||
|
||||
@@ -24,6 +26,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
|
||||
"github.com/ultravioletrs/cocos/pkg/oci"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
@@ -73,7 +76,7 @@ const (
|
||||
algoFilePermission = 0o700
|
||||
)
|
||||
|
||||
const (
|
||||
var (
|
||||
ImaMeasurementsFilePath = "/sys/kernel/security/integrity/ima/ascii_runtime_measurements"
|
||||
ImaPcrIndex = 10
|
||||
)
|
||||
@@ -125,6 +128,10 @@ type Service interface {
|
||||
State() string
|
||||
}
|
||||
|
||||
type OCIClient interface {
|
||||
PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
|
||||
}
|
||||
|
||||
type agentService struct {
|
||||
mu sync.Mutex
|
||||
computation Computation // Holds the current computation request details.
|
||||
@@ -142,6 +149,7 @@ type agentService struct {
|
||||
resultsConsumed bool // Indicates if the results have been consumed.
|
||||
cancel context.CancelFunc // Cancels the computation context.
|
||||
vmpl int // VMPL at which the Agent is running.
|
||||
ociClient OCIClient
|
||||
}
|
||||
|
||||
var _ Service = (*agentService)(nil)
|
||||
@@ -160,12 +168,21 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, atte
|
||||
vmpl: vmlp,
|
||||
}
|
||||
|
||||
workDir := filepath.Join(os.TempDir(), "cocos-oci")
|
||||
skopeoClient, err := oci.NewSkopeoClient(workDir)
|
||||
if err != nil {
|
||||
logger.Warn("failed to create Skopeo client", "error", err)
|
||||
}
|
||||
svc.ociClient = skopeoClient
|
||||
|
||||
transitions := []statemachine.Transition{
|
||||
{From: Idle, Event: Start, To: ReceivingManifest},
|
||||
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
|
||||
}
|
||||
|
||||
transitions = append(transitions, []statemachine.Transition{
|
||||
{From: ReceivingAlgorithm, Event: RunFailed, To: Failed},
|
||||
{From: ReceivingData, Event: RunFailed, To: Failed},
|
||||
{From: Running, Event: RunComplete, To: ConsumingResults},
|
||||
{From: Running, Event: RunFailed, To: Failed},
|
||||
{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
|
||||
@@ -175,8 +192,8 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, atte
|
||||
sm.AddTransition(t)
|
||||
}
|
||||
|
||||
sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingAlgorithm, svc.downloadAlgorithmIfRemote)
|
||||
sm.SetAction(ReceivingData, svc.downloadDatasetsIfRemote)
|
||||
sm.SetAction(Running, svc.runComputation)
|
||||
sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
|
||||
sm.SetAction(Complete, svc.publishEvent(Completed.String()))
|
||||
@@ -211,6 +228,38 @@ func (as *agentService) InitComputation(ctx context.Context, cmp Computation) er
|
||||
|
||||
as.computation = cmp
|
||||
|
||||
// Debug: Log manifest details
|
||||
as.logger.Info("received computation manifest",
|
||||
"computation_id", cmp.ID,
|
||||
"kbs_enabled", cmp.KBS.Enabled,
|
||||
"kbs_url", cmp.KBS.URL,
|
||||
"algo_has_source", cmp.Algorithm.Source != nil,
|
||||
"dataset_count", len(cmp.Datasets))
|
||||
|
||||
if cmp.Algorithm.Source != nil {
|
||||
as.logger.Info("algorithm remote source configured",
|
||||
"url", cmp.Algorithm.Source.URL,
|
||||
"kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath)
|
||||
} else {
|
||||
as.logger.Info("algorithm remote source NOT configured - will wait for direct upload")
|
||||
}
|
||||
|
||||
if cmp.KBS.Enabled {
|
||||
as.logger.Info("KBS is ENABLED", "url", cmp.KBS.URL)
|
||||
} else {
|
||||
as.logger.Info("KBS is NOT ENABLED")
|
||||
}
|
||||
|
||||
for i, d := range cmp.Datasets {
|
||||
if d.Source != nil {
|
||||
as.logger.Info("dataset remote source configured",
|
||||
"index", i,
|
||||
"filename", d.Filename,
|
||||
"url", d.Source.URL,
|
||||
"kbs_resource_path", d.Source.KBSResourcePath)
|
||||
}
|
||||
}
|
||||
|
||||
transitions := []statemachine.Transition{}
|
||||
|
||||
if len(cmp.Datasets) == 0 {
|
||||
@@ -276,6 +325,320 @@ func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadAlgorithmIfRemote automatically downloads the algorithm if it has a remote source.
|
||||
// This is called as an action when entering the ReceivingAlgorithm state.
|
||||
func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) {
|
||||
as.publishEvent(InProgress.String())(state)
|
||||
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
|
||||
// Debug: Log decision point
|
||||
as.logger.Info("checking if algorithm should be downloaded automatically",
|
||||
"algo_has_source", as.computation.Algorithm.Source != nil,
|
||||
"kbs_enabled", as.computation.KBS.Enabled)
|
||||
|
||||
// Check if algorithm should be downloaded from remote source
|
||||
if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading algorithm from remote source",
|
||||
"url", as.computation.Algorithm.Source.URL,
|
||||
"kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath)
|
||||
|
||||
// Use background context for download operation
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("failed to download and decrypt algorithm: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify hash
|
||||
hash := sha3.Sum256(res.Data)
|
||||
if hash != as.computation.Algorithm.Hash {
|
||||
as.runError = fmt.Errorf("algorithm hash mismatch: expected %x, got %x", as.computation.Algorithm.Hash, hash)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Write algorithm to file
|
||||
currentDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("error getting current directory: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// If a source directory is available (e.g. from OCI extraction), copy all files
|
||||
if res.SourceDir != "" {
|
||||
as.logger.Info("copying extracted algorithm directory", "src", res.SourceDir, "dst", currentDir)
|
||||
// Simple recursive copy (using shell cp for simplicity and reliability on Linux)
|
||||
// Ensure we copy contents of SourceDir into currentDir
|
||||
// Simple recursive copy (using shell cp for simplicity and reliability on Linux)
|
||||
// Ensure we copy contents of SourceDir into currentDir
|
||||
cmd := exec.Command("cp", "-r", res.SourceDir+"/.", currentDir)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
as.runError = fmt.Errorf("error copying algorithm directory: %v, output: %s", err, out)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
f, err := os.Create(filepath.Join(currentDir, "algo"))
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("error creating algorithm file: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := f.Write(res.Data); err != nil {
|
||||
as.runError = fmt.Errorf("error writing algorithm to file: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
f.Close()
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Chmod(f.Name(), algoFilePermission); err != nil {
|
||||
as.runError = fmt.Errorf("error changing file permissions: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
f.Close()
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
as.runError = fmt.Errorf("error closing file: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
as.algoReceived = true
|
||||
as.algoRequirements = res.Requirements // Store requirements for installation
|
||||
|
||||
// Create datasets directory
|
||||
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
|
||||
as.runError = fmt.Errorf("error creating datasets directory: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
as.algoType = as.computation.Algorithm.AlgoType
|
||||
if as.algoType == "" {
|
||||
as.algoType = string(algorithm.AlgoTypeBin)
|
||||
}
|
||||
as.algoArgs = as.computation.Algorithm.AlgoArgs
|
||||
|
||||
as.logger.Info("algorithm downloaded and saved successfully", "type", as.algoType, "has_requirements", len(res.Requirements) > 0)
|
||||
as.sm.SendEvent(AlgorithmReceived)
|
||||
} else {
|
||||
// If no remote source, do nothing - wait for direct upload via Algo() RPC call
|
||||
as.logger.Info("algorithm automatic download not triggered, waiting for direct upload",
|
||||
"reason", "no remote source or KBS not enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// downloadDatasetsIfRemote automatically downloads datasets that have remote sources.
|
||||
// This is called as an action when entering the ReceivingData state.
|
||||
func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
as.publishEvent(InProgress.String())(state)
|
||||
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
|
||||
// Check if any datasets should be downloaded from remote sources
|
||||
hasRemoteDatasets := false
|
||||
for _, d := range as.computation.Datasets {
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
hasRemoteDatasets = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRemoteDatasets {
|
||||
// No remote datasets, wait for direct uploads via Data() RPC calls
|
||||
return
|
||||
}
|
||||
|
||||
// Download all remote datasets
|
||||
ctx := context.Background()
|
||||
for i := len(as.computation.Datasets) - 1; i >= 0; i-- {
|
||||
d := as.computation.Datasets[i]
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename)
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
|
||||
if err != nil {
|
||||
as.logger.Error("failed to download and decrypt dataset", "error", err, "filename", d.Filename)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify hash
|
||||
hash := sha3.Sum256(res.Data)
|
||||
if hash != d.Hash {
|
||||
as.logger.Error("dataset hash mismatch", "filename", d.Filename)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Write dataset to file
|
||||
f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, d.Filename))
|
||||
if err != nil {
|
||||
as.logger.Error("error creating dataset file", "error", err, "filename", d.Filename)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
if d.Decompress {
|
||||
if err := internal.UnzipFromMemory(res.Data, algorithm.DatasetsDir); err != nil {
|
||||
as.logger.Error("error decompressing dataset", "error", err, "filename", d.Filename)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if _, err := f.Write(res.Data); err != nil {
|
||||
as.logger.Error("error writing dataset to file", "error", err, "filename", d.Filename)
|
||||
f.Close()
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
as.logger.Error("error closing file", "error", err, "filename", d.Filename)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove from pending datasets
|
||||
as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
|
||||
as.logger.Info("dataset downloaded and saved successfully", "filename", d.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
// If all datasets are downloaded, send DataReceived event
|
||||
if len(as.computation.Datasets) == 0 {
|
||||
as.logger.Info("all datasets downloaded successfully")
|
||||
as.sm.SendEvent(DataReceived)
|
||||
}
|
||||
// Otherwise, wait for remaining datasets to be uploaded via Data() RPC calls
|
||||
}
|
||||
|
||||
// DecryptedResource holds the data and metadata of a downloaded and decrypted resource.
|
||||
type DecryptedResource struct {
|
||||
Data []byte
|
||||
Requirements []byte
|
||||
SourceDir string
|
||||
}
|
||||
|
||||
// downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider.
|
||||
// For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically.
|
||||
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
|
||||
// Determine source type
|
||||
sourceType := source.Type
|
||||
if sourceType == "" {
|
||||
// Infer from URL
|
||||
if strings.HasPrefix(source.URL, "docker://") || strings.HasPrefix(source.URL, "oci:") {
|
||||
sourceType = "oci-image"
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported source URL format: %s (use oci-image type)", source.URL)
|
||||
}
|
||||
}
|
||||
|
||||
switch sourceType {
|
||||
case "oci-image":
|
||||
return as.downloadAndDecryptOCIImage(ctx, source, resourceType)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported source type: %s", sourceType)
|
||||
}
|
||||
}
|
||||
|
||||
// downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider.
|
||||
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
|
||||
as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)",
|
||||
source.URL, source.Encrypted, source.KBSResourcePath))
|
||||
|
||||
// Create Skopeo client
|
||||
if as.ociClient == nil {
|
||||
return nil, fmt.Errorf("OCI client not initialized")
|
||||
}
|
||||
|
||||
// Create OCI resource source
|
||||
ociSource := oci.ResourceSource{
|
||||
Type: oci.ResourceTypeOCIImage,
|
||||
URI: source.URL,
|
||||
Encrypted: source.Encrypted,
|
||||
KBSResourcePath: source.KBSResourcePath,
|
||||
}
|
||||
|
||||
// Pull and decrypt image
|
||||
// CoCo Keyprovider will automatically handle decryption via ocicrypt
|
||||
// Sanitize directory name to avoid Skopeo interpreting ':' as tag separator
|
||||
sanitizedName := strings.ReplaceAll(filepath.Base(source.URL), ":", "_")
|
||||
destDir := filepath.Join(os.TempDir(), "cocos-oci", "images", sanitizedName)
|
||||
if err := as.ociClient.PullAndDecrypt(ctx, ociSource, destDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull and decrypt OCI image: %w", err)
|
||||
}
|
||||
|
||||
as.logger.Info("OCI image downloaded and decrypted", "dest", destDir)
|
||||
|
||||
// Extract algorithm file from OCI layers
|
||||
extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName)
|
||||
var algorithmPath string
|
||||
var err error
|
||||
|
||||
if resourceType == "algorithm" {
|
||||
algorithmPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
|
||||
}
|
||||
as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
|
||||
} else {
|
||||
// Assume dataset
|
||||
files, err := oci.ExtractDataset(destDir, extractDir)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err)
|
||||
}
|
||||
// For now, take the first file found.
|
||||
// nolint:godox // TODO: Handle multiple files / directory structure if needed.
|
||||
algorithmPath = files[0]
|
||||
as.logger.Info("dataset extracted from OCI image", "path", algorithmPath)
|
||||
}
|
||||
|
||||
// Read algorithm file
|
||||
algorithmData, err := os.ReadFile(algorithmPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read algorithm file: %w", err)
|
||||
}
|
||||
|
||||
// Check for requirements.txt if algorithm
|
||||
var reqData []byte
|
||||
if resourceType == "algorithm" {
|
||||
reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
|
||||
if data, err := os.ReadFile(reqPath); err == nil {
|
||||
reqData = data
|
||||
as.logger.Info("found requirements.txt", "size", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
as.logger.Info("algorithm loaded", "size", len(algorithmData))
|
||||
|
||||
return &DecryptedResource{
|
||||
Data: algorithmData,
|
||||
Requirements: reqData,
|
||||
SourceDir: filepath.Dir(algorithmPath),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
if as.sm.GetState() != ReceivingAlgorithm {
|
||||
return ErrStateNotReady
|
||||
@@ -286,7 +649,25 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(algo.Algorithm)
|
||||
var algoData []byte
|
||||
|
||||
// Check if algorithm should be downloaded from remote source
|
||||
if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading algorithm from remote source")
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download and decrypt algorithm: %w", err)
|
||||
}
|
||||
|
||||
algoData = res.Data
|
||||
as.algoRequirements = res.Requirements
|
||||
} else {
|
||||
// Use directly uploaded algorithm
|
||||
algoData = algo.Algorithm
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(algoData)
|
||||
|
||||
if hash != as.computation.Algorithm.Hash {
|
||||
return ErrHashMismatch
|
||||
@@ -302,7 +683,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
return fmt.Errorf("error creating algorithm file: %v", err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(algo.Algorithm); err != nil {
|
||||
if _, err := f.Write(algoData); err != nil {
|
||||
return fmt.Errorf("error writing algorithm to file: %v", err)
|
||||
}
|
||||
|
||||
@@ -347,28 +728,55 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(dataset.Dataset)
|
||||
var datasetData []byte
|
||||
var datasetFilename string
|
||||
|
||||
// Check if any dataset should be downloaded from remote source
|
||||
matchedIndex := -1
|
||||
for i, d := range as.computation.Datasets {
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename)
|
||||
|
||||
downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download and decrypt dataset: %w", err)
|
||||
}
|
||||
|
||||
datasetData = downloadedData.Data
|
||||
datasetFilename = d.Filename
|
||||
matchedIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no remote dataset, use uploaded dataset
|
||||
if matchedIndex == -1 {
|
||||
datasetData = dataset.Dataset
|
||||
datasetFilename = dataset.Filename
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(datasetData)
|
||||
|
||||
matched := false
|
||||
for i, d := range as.computation.Datasets {
|
||||
if hash == d.Hash {
|
||||
if d.Filename != "" && d.Filename != dataset.Filename {
|
||||
if d.Filename != "" && d.Filename != datasetFilename {
|
||||
return ErrFileNameMismatch
|
||||
}
|
||||
|
||||
as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
|
||||
|
||||
if DecompressFromContext(ctx) {
|
||||
if err := internal.UnzipFromMemory(dataset.Dataset, algorithm.DatasetsDir); err != nil {
|
||||
if err := internal.UnzipFromMemory(datasetData, algorithm.DatasetsDir); err != nil {
|
||||
return fmt.Errorf("error decompressing dataset: %v", err)
|
||||
}
|
||||
} else {
|
||||
f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, dataset.Filename))
|
||||
f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, datasetFilename))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating dataset file: %v", err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(dataset.Dataset); err != nil {
|
||||
if _, err := f.Write(datasetData); err != nil {
|
||||
return fmt.Errorf("error writing dataset to file: %v", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
|
||||
@@ -3,10 +3,14 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -19,6 +23,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
agentevents "github.com/ultravioletrs/cocos/agent/events"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
@@ -26,11 +31,21 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/oci"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
type MockOCIClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error {
|
||||
args := m.Called(ctx, source, destDir)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
var (
|
||||
algoPath = "../test/manual/algo/lin_reg.py"
|
||||
reqPath = "../test/manual/algo/requirements.txt"
|
||||
@@ -672,3 +687,550 @@ func TestStopComputationConcurrent(t *testing.T) {
|
||||
|
||||
assert.True(t, len(errors) < numGoroutines, "All StopComputation calls failed")
|
||||
}
|
||||
|
||||
// newTestAgentService creates a minimal agentService for direct method testing.
|
||||
func newTestAgentService(sm statemachine.StateMachine, eventSvc agentevents.Service) *agentService {
|
||||
return &agentService{
|
||||
logger: slog.Default(),
|
||||
eventSvc: eventSvc,
|
||||
sm: sm,
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", mock.Anything).Return().Maybe()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("unsupported URL format no type", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "http://unsupported-format"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("ftp URL unsupported format", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "ftp://some-server/file"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("unsupported explicit source type", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source type: s3-bucket")
|
||||
})
|
||||
|
||||
t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||||
// This exercises the oci-image path; will fail at skopeo step
|
||||
source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
// Should be a skopeo or OCI error, not an "unsupported" error
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("oci: URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "oci:some-local-dir"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("explicit oci-image type routes to skopeo", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source type")
|
||||
})
|
||||
|
||||
t.Run("dataset resource type with oci-image", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "dataset")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDownloadAlgorithmIfRemote(t *testing.T) {
|
||||
t.Run("no source configured - no-op, waits for direct upload", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
// No SendEvent expected — just the no-op path
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{} // Algorithm.Source == nil
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.Nil(t, svc.runError)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("source set but KBS disabled - no-op", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Source: &ResourceSource{URL: "docker://registry/algo:latest"},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: false},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.Nil(t, svc.runError)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("source + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://invalid.example.com/algo:latest",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.NotNil(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "failed to download and decrypt algorithm")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("unsupported URL format - download fails, sends RunFailed", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Source: &ResourceSource{
|
||||
URL: "http://unsupported-format/algo",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.NotNil(t, svc.runError)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
t.Run("no datasets with remote sources - no-op", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
// Dataset with no Source
|
||||
dataHash := sha3.Sum256([]byte("testdata"))
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{Hash: dataHash, Filename: "data.csv"},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
// No RunFailed event, no DataReceived event
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("no datasets at all - no-op", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("KBS disabled even with source - no-op", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Source: &ResourceSource{URL: "docker://registry/data:latest"},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: false},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("remote dataset + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://invalid.example.com/data:latest",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("unsupported URL fails - sends RunFailed", func(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Source: &ResourceSource{
|
||||
URL: "ftp://unsupported/data",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunComputation(t *testing.T) {
|
||||
// Helper to set up a temp working directory and restore CWD afterwards.
|
||||
withTempDir := func(t *testing.T) (tmpDir string, restore func()) {
|
||||
t.Helper()
|
||||
origDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
tmpDir = t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
return tmpDir, func() { _ = os.Chdir(origDir) }
|
||||
}
|
||||
|
||||
t.Run("algo file not found sends RunFailed", func(t *testing.T) {
|
||||
_, restore := withTempDir(t)
|
||||
defer restore()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
// No algo file exists – runComputation should hit the ReadFile error path.
|
||||
svc.runComputation(Running)
|
||||
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "failed to read algo file")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("runner client returns error sends RunFailed", func(t *testing.T) {
|
||||
_, restore := withTempDir(t)
|
||||
defer restore()
|
||||
|
||||
// Write a dummy algo file so ReadFile succeeds.
|
||||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||||
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return((*runnerpb.RunResponse)(nil), fmt.Errorf("runner unavailable"))
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.runnerClient = runnerCli
|
||||
|
||||
svc.runComputation(Running)
|
||||
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "runner unavailable")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("runner returns non-empty error field sends RunFailed", func(t *testing.T) {
|
||||
_, restore := withTempDir(t)
|
||||
defer restore()
|
||||
|
||||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||||
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{Error: "computation crashed"}, nil)
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.runnerClient = runnerCli
|
||||
|
||||
svc.runComputation(Running)
|
||||
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "computation crashed")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIMAMeasurements(t *testing.T) {
|
||||
t.Run("error when IMA measurements file does not exist in non-SGX environment", func(t *testing.T) {
|
||||
// In a regular test environment (non-SGX), the IMA measurements file
|
||||
// at /sys/kernel/security/integrity/ima/ascii_runtime_measurements won't exist.
|
||||
// Verify our error handling works correctly.
|
||||
origPath := ImaMeasurementsFilePath
|
||||
ImaMeasurementsFilePath = "/non/existent/path"
|
||||
defer func() { ImaMeasurementsFilePath = origPath }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
|
||||
data, pcr10, err := svc.IMAMeasurements(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error reading Linux IMA measurements file")
|
||||
assert.Nil(t, data)
|
||||
assert.Nil(t, pcr10)
|
||||
})
|
||||
|
||||
t.Run("successful reading of IMA measurements", func(t *testing.T) {
|
||||
tempFile := filepath.Join(t.TempDir(), "ima_measurements")
|
||||
content := []byte("10 sha1:0000000000000000000000000000000000000000 ima-ng sha256:0000000000000000000000000000000000000000000000000000000000000000 /usr/bin/python3\n")
|
||||
err := os.WriteFile(tempFile, content, 0o644)
|
||||
require.NoError(t, err)
|
||||
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
||||
|
||||
origPath := ImaMeasurementsFilePath
|
||||
ImaMeasurementsFilePath = tempFile
|
||||
defer func() { ImaMeasurementsFilePath = origPath }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
|
||||
data, pcr10, err := svc.IMAMeasurements(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, content, data)
|
||||
assert.NotEmpty(t, pcr10)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||||
// Skip this test in short mode as it might involve more setup if we were using real OCI
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
algoContent := []byte("print('hello')")
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "main.py", string(algoContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
algoContent = []byte("print('hello')")
|
||||
algoHash := sha3.Sum256(algoContent)
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Hash: algoHash,
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
// We need to bypass oci.ExtractAlgorithm by manually creating what it would create
|
||||
// OR use a real-enough looking OCI layout.
|
||||
// Since we can't easily mock oci.ExtractAlgorithm, we'll try to provide a minimal OCI layout
|
||||
// so that oci.ExtractAlgorithm doesn't fail.
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
|
||||
assert.Nil(t, svc.runError)
|
||||
assert.True(t, svc.algoReceived)
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func setupMinimalOCI(t *testing.T, ociDir, filename, content string) {
|
||||
t.Helper()
|
||||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||||
|
||||
layerPath := filepath.Join(blobsDir, "layer123")
|
||||
layerFile, err := os.Create(layerPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
gw := gzip.NewWriter(layerFile)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
hdr := &tar.Header{
|
||||
Name: filename,
|
||||
Mode: 0o755,
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr))
|
||||
_, err = tw.Write([]byte(content))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, tw.Close())
|
||||
require.NoError(t, gw.Close())
|
||||
require.NoError(t, layerFile.Close())
|
||||
|
||||
manifest := struct {
|
||||
Layers []struct {
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}{
|
||||
Layers: []struct {
|
||||
Digest string `json:"digest"`
|
||||
}{{Digest: "sha256:layer123"}},
|
||||
}
|
||||
manifestData, err := json.Marshal(manifest)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
|
||||
|
||||
index := oci.OCIIndex{
|
||||
SchemaVersion: 2,
|
||||
Manifests: []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int `json:"size"`
|
||||
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
|
||||
}
|
||||
indexData, err := json.Marshal(index)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
}
|
||||
|
||||
func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", DataReceived).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
dataContent := []byte("a,b,c\n1,2,3")
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.csv", string(dataContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
dataContent = []byte("a,b,c\n1,2,3")
|
||||
dataHash := sha3.Sum256(dataContent)
|
||||
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Hash: dataHash,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
|
||||
assert.Nil(t, svc.runError)
|
||||
assert.Len(t, svc.computation.Datasets, 0)
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user