NOISSUE - Agent Pull mode for remote resources (#575)
CI / checkproto (push) Has been cancelled
CI / lint (push) Has been cancelled
Rust CI Pipeline / rust-check (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* feat(kbs): implement KBS client for attestation and resource retrieval

- Added KBS client implementation in pkg/kbs/client.go with methods for attestation and resource retrieval.
- Introduced necessary data structures for requests and responses.
- Implemented error handling for various scenarios.

test(kbs): add unit tests for KBS client

- Created comprehensive tests for the KBS client in pkg/kbs/client_test.go.
- Included tests for attestation success and failure cases, as well as resource retrieval.

feat(registry): introduce HTTP and S3 registry implementations

- Added HTTPRegistry for downloading resources over HTTP/HTTPS with retry logic in pkg/registry/http.go.
- Implemented S3Registry for downloading resources from AWS S3 and S3-compatible services in pkg/registry/s3.go.
- Included error handling and configuration options for both registries.

chore(registry): define registry interface and configuration

- Created registry interface and configuration struct in pkg/registry/registry.go.
- Added default configuration settings for registry clients.

docs(cvms): update README for CVMS server configuration and usage

- Enhanced documentation for CVMS server with detailed command-line flags and usage examples.
- Clarified direct upload and remote resource modes, including KBS integration.

fix(cvms): integrate KBS for remote resource handling in main.go

- Updated main.go to support remote datasets and algorithms using KBS.
- Added validation for command-line flags to ensure proper configuration.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Move ifeq conditional outside define block in attestation-service.mk

Make conditionals cannot be evaluated inside define...endef blocks
when used as recipe bodies. Restructured to define the
ATTESTATION_SERVICE_INSTALL_INIT_SYSTEMD block conditionally based
on BR2_PACKAGE_CC_ATTESTATION_AGENT configuration.

* feat: Implement remote resource downloading for algorithms and datasets using AWS S3/MinIO credentials.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add comprehensive documentation and agent support for testing remote resource download with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Improve agent logging for remote resource configuration and KBS status, and add a testing guide for remote resource downloads with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add a comprehensive guide for testing remote resource download with KBS attestation and update multiple package versions to a specific commit.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add failure transitions for resource reception states and a comprehensive guide for testing remote resource downloads with KBS attestation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Implement remote resource download with KBS attestation in the agent and add a comprehensive testing guide.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Add comprehensive guide for testing remote resource download with KBS attestation and include a debug log in the attestation client.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Delegate KBS attestation and token retrieval to a new attestation-agent service and document remote resource testing.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* client fixes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* raw evidence

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Build all Go files in cmd directories, not just main.go

This fixes the issue where fetch_raw_evidence.go wasn't being included
in the attestation-service build.

* fix: Wrap binary evidence in JSON for KBS compatibility

Fixes 'invalid character' error by wrapping raw binary evidence
in a JSON structure with base64 encoding, as expected by KBS.

* chore: Update buildroot packages to c28cefae

Includes fixes for:
1. attestation-service build (including fetch_raw_evidence.go)
2. Agent KBS evidence format (wrapping binary in JSON)

* fix: Implement KBS RCAR handshake with cookies

Fixes 'cookie not found' error (401) from KBS by:
1. Adding CookieJar support to KBS client
2. Implementing GetChallenge() to perform /auth handshake and capture session cookie
3. Updating Agent to get challenge, decode nonce, and use it for evidence generation
4. Regenerating mocks

* chore: Update buildroot packages to f6981ac5

Includes KBS RCAR handshake fix (cookie support + GetChallenge loop)

* fix: Update KBS client JSON tags to kebab-case

Fixes deserialization error (401) from KBS by:
1. Using kebab-case (e.g. extra-params) for JSON tags as per protocol.
2. Initializing ExtraParams as empty object {} instead of null/omitted.

* fix: Wrap attestation evidence in primary_evidence format

Updates Agent to construct 'tee-evidence' payload with:
- primary_evidence: containing the actual quote/data
- additional_evidence: empty JSON object

This matches the Confidential Containers KBS Attestation Protocol requirements.

* fix: Update KBS protocol version to 0.4.0

KBS rejected 0.1.0 with a version mismatch error. Bumping to 0.4.0 to match server expectation.

* fix: Generate ephemeral key for KBS RuntimeData

Updates RuntimeData to include a valid ephemeral EC P-256 public key in JWK format, as required by the KBS RCAR protocol.
Also fixes the KBS client struct to support TEEPubKey as an object.

* fix: Update sample attestation quote to valid JSON

The default attestation.bin was binary, but the KBS Sample Verifier expects a valid JSON quote containing 'svn' and 'report_data'.
Updated the embedded bin file to contain this JSON structure.

* fix: Generate dynamic JSON quote for Sample TEE in FetchRawEvidence

The KBS Sample Verifier expects a JSON object with 'svn' and 'report_data'.
Previously, we were returning raw binary data (reportData+nonce).
This commit updates FetchRawEvidence to return a marshaled JSON structure with:
- svn: "1"
- report_data: base64(req.ReportData)

* refactor: Delegate Sample Attestation to Provider

Refactored sample attestation logic:
- Moved JSON Quote generation into EmptyProvider (standalone mode).
- Updated FetchRawEvidence to call provider.TeeAttestation instead of manual generation.
This enables using the real CC Attestation Agent for UNSPECIFIED platform if configured.

* feat: Add comprehensive debug logging and enforce CC AA usage

Changes:
- Updated EmptyProvider to return error instead of generating mock data
  This forces proper use of CC Attestation Agent's sample attester
- Added detailed logging to attestation-service FetchRawEvidence:
  * Hex dump of evidence (first 200 bytes)
  * String preview of evidence
  * Total evidence length
- Added detailed logging to agent service:
  * Raw evidence hex and string previews
  * KBS evidence JSON preview (first 500 bytes)
  * Evidence lengths at each transformation step

This logging will help diagnose why KBS Sample Verifier is rejecting evidence.

* fix: Enable CC AA by default and add attestation-service log forwarding

Changes:
- Set USE_CC_ATTESTATION_AGENT=true by default in systemd service
- Added StandardOutput/StandardError to forward logs to /var/log/cocos/
- Updated HAL makefile to handle new default value
- This ensures attestation-service uses CC AA's sample attester
- Logs will now be visible in CVMS output for debugging

* feat: Add gRPC log forwarding to attestation-service

Implemented the same log forwarding mechanism used by the agent:
- Added ProtoHandler to write logs to both stdout and logQueue
- Connected to log client (/run/cocos/log.sock) for gRPC forwarding
- Added goroutine to forward logs to CVMS via log client
- Logs will now appear in CVMS output during computation runs

This enables visibility into attestation-service debug output including:
- CC AA connection status
- Evidence generation details (hex dumps, string previews)
- Any errors from providers

* fix: Parse sample evidence JSON instead of base64-encoding it

The attestation-service returns sample evidence as JSON:
{"svn":"1","report_data":"base64..."}

The agent was incorrectly base64-encoding this JSON string again.
KBS Sample Verifier expects the parsed JSON object directly.

Fixed by:
- Parsing the JSON evidence from attestation-service
- Passing the parsed object directly in primary_evidence.evidence
- This matches what KBS Sample Verifier expects

* debug: Increase KBS evidence logging preview to 1000 bytes

Show the complete JSON structure being sent to KBS to debug
the attestation failure.

* debug: Add comprehensive CC AA configuration logging

Added debug logs to show:
- Whether CC AA is enabled in config
- CC AA address being used
- Connection success/failure
- Which provider is ultimately selected
- Warning when falling back to EmptyProvider

This will help diagnose why EmptyProvider is being used
instead of CC Attestation Agent.

* debug: Add startup logging for log client connection

Added log message to show if log client connection succeeds
at attestation-service startup. This will help diagnose why
logs aren't appearing in CVMS output.

* feat: Add retry logic with exponential backoff to log client

Added simple retry mechanism to handle concurrent log requests:
- 3 retry attempts with exponential backoff (10ms, 20ms, 40ms)
- Applies to both SendLog and SendEvent methods
- Centralized in log client so all services benefit
- Should eliminate 'failed to send log' errors from concurrent requests

This fixes the issue where attestation-service logs weren't
appearing in CVMS output due to dropped messages.

* fix: Flatten sample evidence fields in primary_evidence for KBS

KBS Sample Verifier expects svn and report_data at the top level
of primary_evidence, not nested under an 'evidence' key.

Changed structure from:
{"primary_evidence": {"tee": "sample", "evidence": {"svn": "1", ...}}}

To:
{"primary_evidence": {"tee": "sample", "svn": "1", "report_data": "...", ...}}

This matches what KBS expects when deserializing the Quote structure.

* fix: Use sample quote directly as primary_evidence per KBS protocol

According to KBS attestation protocol spec, for sample TEE type,
primary_evidence should be the sample quote JSON directly:
{"svn": "1", "report_data": "..."}

Removed extra 'tee' and 'platform' fields that were causing KBS
to fail deserializing the Quote structure. The 'tee' field is
already sent in the Request payload during RCAR handshake.

Refs:
- https://github.com/confidential-containers/trustee/blob/main/kbs/docs/kbs_attestation_protocol.md
- https://github.com/confidential-containers/guest-components/blob/main/attestation-agent/attester/src/sample/mod.rs

* fix: Make CC AA required for sample attestation when configured

When USE_CC_ATTESTATION_AGENT=true, attestation-service now
requires AA to be available for NoCC/sample platform. This ensures
sample evidence always comes from AA with the correct KBS format.

Changes:
- Error out if AA connection fails for NoCC platform when AA is configured
- Only use EmptyProvider if AA is explicitly NOT configured
- Prevents incorrect sample evidence format from EmptyProvider

This ensures attestation-service delegates to AA for sample evidence
generation instead of creating it itself.

* fix: Implement proper RCAR protocol with tee-pubkey and runtime-data hash

Fixed KBS attestation error 'REPORT_DATA is different from that in Sample Quote'

Changes:
1. Generate ephemeral EC key pair BEFORE getting evidence from AA
2. Create runtime-data with nonce + tee-pubkey (JWK format)
3. Hash runtime-data (SHA-256) and use as report_data for AA
4. This binds the tee-pubkey to the TEE evidence per RCAR protocol

The report_data in the evidence now matches what KBS expects:
hash(runtime-data) instead of computation ID.

This completes the full RCAR protocol implementation:
- Request → Challenge → Attestation (with bound tee-pubkey) → Response

* fix(agent): use simple nonce for Sample attestation report_data

For Sample/NoCC attestation, use the raw nonce bytes directly as
report_data instead of hashing runtime-data. This avoids JSON
serialization mismatches with the KBS Sample verifier.

Real TEEs (TDX/SNP) still use runtime-data hash binding to
cryptographically bind the ephemeral tee-pubkey to the evidence.

* fix(agent): use RFC 8785 canonical JSON for runtime-data hashing

The KBS Sample attestation verifier (and likely others) expects the
report_data to be the SHA-256 hash of the *canonical* JSON serialization
(RFC 8785) of the runtime-data. Standard Go JSON marshaling does not
guarantee key ordering, leading to hash mismatches.

This change uses github.com/gowebpki/jcs to canonicalize the runtime-data
before hashing, ensuring compatibility with the KBS RCAR implementation.
Also reverted the temporary 'simple nonce' workaround.

* feat(hal): add CoCo Keyprovider and Skopeo packages

- Add coco-keyprovider buildroot package with systemd service
- Add skopeo buildroot package for OCI image handling
- Add ocicrypt_keyprovider.conf for encrypted image decryption
- Update Config.in to include new packages

This enables standard CoCo ecosystem integration for encrypted
OCI images instead of custom S3/HTTP registry clients.

* feat(oci): add OCI image handling package with Skopeo integration

- Add pkg/oci/types.go with ResourceSource and ImageManifest types
- Add pkg/oci/skopeo.go with Skopeo wrapper for pull/decrypt
- Add pkg/oci/extract.go for extracting algorithms and datasets from layers

This package provides OCI image handling using Skopeo and CoCo
Keyprovider for encrypted image decryption, replacing custom
S3/HTTP registry clients.

* chore: regenerate protobuf files for updated cvms.proto

* refactor(agent): replace S3/HTTP/KBS with OCI package

- Remove pkg/kbs and pkg/registry imports
- Add pkg/oci import for OCI image handling
- Replace downloadAndDecryptResource with OCI-based implementation
- Use Skopeo + CoCo Keyprovider for automatic decryption
- Reduce code from ~240 lines to ~70 lines

This eliminates custom KBS RCAR handshake, S3/HTTP registry clients,
and manual decryption logic. CoCo Keyprovider handles all decryption
automatically via ocicrypt protocol.

* chore: remove obsolete pkg/kbs and pkg/registry packages

- Delete pkg/kbs/ (custom KBS client, ~300 lines)
- Delete pkg/registry/ (S3/HTTP registry clients, ~400 lines)
- Remove unused imports from agent/service.go
- Run go mod tidy to clean up dependencies

These packages have been replaced by pkg/oci with Skopeo and
CoCo Keyprovider for standard CoCo ecosystem integration.

* fix(agent): update ResourceSource struct to include type and encryption fields

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix(hal): update CoCo Keyprovider to v0.16.0 and fix build path

- Update version from v0.11.0 to v0.16.0 (matches attestation agent)
- Fix install path: target is at repo root, not in coco_keyprovider subdir
- This fixes the build error where coco_keyprovider binary wasn't found

The cargo workspace in guest-components builds to a shared target/
directory at the repository root, not within each crate's subdirectory.

* feat: Update remote resources testing guide to use kbs-client and coco-keyprovider for key management and encryption, enable insecure TLS for Skopeo, and enhance CVMS with

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update component versions, revise image encryption documentation, and sanitize OCI image paths for Skopeo compatibility.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add `decompress` option to Dataset and `algo_type`/`algo_args` to Algorithm protobuf messages, updating client, test, and build configurations.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Update multiple package versions and enhance OCI image extraction error reporting for missing algorithm files.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Bump package versions, improve OCI image extraction debugging by returning seen files, and remove unused dataset type parsing from test code.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Migrate OCI extraction to use structured logging with `slog` and `context`, and update package versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Bump multiple component versions, add encrypted status for computation inputs and algorithms, and refine OCI layer extraction warnings.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* logging

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add `Encrypted` field to algorithm and dataset resource sources and update all component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update component versions, integrate coco-keyprovider service, and configure ocicrypt key provider.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add support for KBS parameters and dataset/algorithm hash calculations in CVMS

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update resource download and extraction logic to support requirements.txt and improve hash verification

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Update dependencies, improve code style, and add GetRawEvidence to attestation client mocks.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor code structure for improved readability and maintainability

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: update golangci configuration to include errcheck for build path and remove unnecessary exclusions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: streamline kernel command line handling in QEMU args construction

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add attestation binary and update checksum tests and policy structure

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for attestation agent, attestation, log, crypto, OCI, and Skopeo clients

- Implement tests for the attestation agent client including Unix socket and TCP address handling, token retrieval, and error scenarios.
- Enhance attestation client tests to cover fetching raw evidence for various platforms (SNP, TDX, VTPM, SNPvTPM) and validate error handling.
- Introduce log client tests to verify retry behavior for sending logs and events.
- Create comprehensive tests for crypto package focusing on AES-GCM decryption, encrypted resource parsing, and key unwrapping.
- Add tests for OCI package to validate algorithm and dataset extraction, including JSON serialization of OCILayout.
- Implement Skopeo client tests to ensure proper functionality for image pulling, inspecting, and resource source handling.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: handle JSON marshal errors in test cases for decrypt and extract functions

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add comprehensive tests for algorithm and dataset extraction with various scenarios

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: replace hardcoded Python script content with constant variable

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: remove redundant mock expectation for SendAgentConfig in TestCreateVMWithAaKbsParams

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add tests for event sending failure, dataset extraction with path traversal, and Skopeo client behavior

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add tests for download and decryption of resources with various URL formats

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Introduce OCIClient interface for agent service to improve testability of OCI image operations and enhance related tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Change `get_uint64_from_tcb` to accept `TcbVersion` by value and use `u64::from` for type conversions.

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-03-16 16:48:55 +03:00
committed by GitHub
parent f77ec5644a
commit da31d76c94
76 changed files with 7464 additions and 392 deletions
+8
View File
@@ -57,6 +57,14 @@ func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [
return args.Get(0).([]byte), args.Error(1)
}
func (m *mockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
args := m.Called(ctx, reportData, nonce, attType)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]byte), args.Error(1)
}
func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
args := m.Called(ctx, nonce)
if args.Get(0) == nil {
+10 -3
View File
@@ -3,18 +3,25 @@
package attestation
import cocosai "github.com/ultravioletrs/cocos"
import (
"fmt"
cocosai "github.com/ultravioletrs/cocos"
)
var _ Provider = (*EmptyProvider)(nil)
type EmptyProvider struct{}
func (e *EmptyProvider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) {
return cocosai.EmbeddedAttestation, nil
// For Sample/Empty provider, we treat the teeNonce as reportData
return e.TeeAttestation(teeNonce)
}
func (e *EmptyProvider) TeeAttestation(teeNonce []byte) ([]byte, error) {
return cocosai.EmbeddedAttestation, nil
// EmptyProvider should not be used for attestation
// The CC Attestation Agent's sample attester should be used instead
return nil, fmt.Errorf("EmptyProvider should not be used - configure USE_CC_ATTESTATION_AGENT=true to use the CC Attestation Agent's sample attester")
}
func (e *EmptyProvider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
+163
View File
@@ -0,0 +1,163 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cocosai "github.com/ultravioletrs/cocos"
)
func TestEmptyProvider_Attestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
vTpmNonce []byte
wantErr bool
}{
{
name: "should return error for empty nonces",
teeNonce: []byte{},
vTpmNonce: []byte{},
wantErr: true,
},
{
name: "should return error for valid nonces",
teeNonce: make([]byte, 64),
vTpmNonce: make([]byte, 32),
wantErr: true,
},
{
name: "should return error for nil nonces",
teeNonce: nil,
vTpmNonce: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &EmptyProvider{}
got, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
assert.Contains(t, err.Error(), "EmptyProvider should not be used")
} else {
assert.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestEmptyProvider_TeeAttestation(t *testing.T) {
tests := []struct {
name string
teeNonce []byte
wantErr bool
}{
{
name: "should return error for empty nonce",
teeNonce: []byte{},
wantErr: true,
},
{
name: "should return error for valid nonce",
teeNonce: make([]byte, 64),
wantErr: true,
},
{
name: "should return error for nil nonce",
teeNonce: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &EmptyProvider{}
got, err := p.TeeAttestation(tt.teeNonce)
assert.Error(t, err)
assert.Nil(t, got)
assert.Contains(t, err.Error(), "EmptyProvider should not be used")
})
}
}
func TestEmptyProvider_VTpmAttestation(t *testing.T) {
tests := []struct {
name string
vTpmNonce []byte
wantErr bool
}{
{
name: "should return embedded attestation for empty nonce",
vTpmNonce: []byte{},
wantErr: false,
},
{
name: "should return embedded attestation for valid nonce",
vTpmNonce: make([]byte, 32),
wantErr: false,
},
{
name: "should return embedded attestation for nil nonce",
vTpmNonce: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &EmptyProvider{}
got, err := p.VTpmAttestation(tt.vTpmNonce)
require.NoError(t, err)
assert.Equal(t, cocosai.EmbeddedAttestation, got)
})
}
}
func TestEmptyProvider_AzureAttestationToken(t *testing.T) {
tests := []struct {
name string
nonce []byte
wantErr bool
}{
{
name: "should return nil for empty nonce",
nonce: []byte{},
wantErr: false,
},
{
name: "should return nil for valid nonce",
nonce: make([]byte, 32),
wantErr: false,
},
{
name: "should return nil for nil nonce",
nonce: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &EmptyProvider{}
got, err := p.AzureAttestationToken(tt.nonce)
require.NoError(t, err)
assert.Nil(t, got)
})
}
}
func TestEmptyProvider_ImplementsProvider(t *testing.T) {
var _ Provider = (*EmptyProvider)(nil)
}
+3
View File
@@ -395,6 +395,9 @@ func getPCRValue(index int, algorithm tpm2.Algorithm) ([]byte, error) {
pcrValue, err := tpm2.ReadPCR(rwc, index, algorithm)
if err != nil {
if _, ok := ExternalTPM.(*DummyRWC); ok {
return make([]byte, 20), nil
}
return nil, err
}
@@ -0,0 +1,70 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation_agent
import (
"context"
"fmt"
"strings"
"time"
aa "github.com/ultravioletrs/cocos/internal/proto/attestation-agent"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Client provides access to attestation-agent services.
type Client interface {
// GetToken gets a token from the attestation-agent (e.g., KBS token).
GetToken(ctx context.Context, tokenType string) ([]byte, error)
Close() error
}
type client struct {
conn *grpc.ClientConn
client aa.AttestationAgentServiceClient
}
// NewClient creates a new attestation-agent client.
// address can be either a TCP address (e.g., "127.0.0.1:50002") or Unix socket path (e.g., "/run/aa.sock").
func NewClient(address string) (Client, error) {
var target string
// If address contains ":", it's a TCP address, otherwise it's a Unix socket
if strings.Contains(address, ":") {
target = address
} else {
target = "unix://" + address
}
conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("failed to connect to attestation-agent: %w", err)
}
return &client{
conn: conn,
client: aa.NewAttestationAgentServiceClient(conn),
}, nil
}
func (c *client) Close() error {
return c.conn.Close()
}
// GetToken gets a token from the attestation-agent.
// tokenType should be "kbs" for KBS tokens.
func (c *client) GetToken(ctx context.Context, tokenType string) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
req := &aa.GetTokenRequest{
TokenType: tokenType,
}
resp, err := c.client.GetToken(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get token from attestation-agent: %w", err)
}
return resp.Token, nil
}
@@ -0,0 +1,194 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation_agent
import (
"context"
"net"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
aa "github.com/ultravioletrs/cocos/internal/proto/attestation-agent"
"google.golang.org/grpc"
)
type mockAttestationAgentServer struct {
aa.UnimplementedAttestationAgentServiceServer
getTokenCalled bool
lastTokenType string
tokenErr error
tokenResponse []byte
}
func (m *mockAttestationAgentServer) GetToken(ctx context.Context, req *aa.GetTokenRequest) (*aa.GetTokenResponse, error) {
m.getTokenCalled = true
m.lastTokenType = req.TokenType
if m.tokenErr != nil {
return nil, m.tokenErr
}
return &aa.GetTokenResponse{Token: m.tokenResponse}, nil
}
func TestNewClientUnixSocket(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "aa-test.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{tokenResponse: []byte("mock-token")}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
err = client.Close()
assert.NoError(t, err)
}
func TestNewClientTCPAddress(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{tokenResponse: []byte("mock-token")}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(listener.Addr().String())
require.NoError(t, err)
require.NotNil(t, client)
err = client.Close()
assert.NoError(t, err)
}
func TestGetToken(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "aa-gettoken.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{tokenResponse: []byte("kbs-token-response")}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
token, err := client.GetToken(ctx, "kbs")
require.NoError(t, err)
assert.Equal(t, []byte("kbs-token-response"), token)
assert.True(t, mockServer.getTokenCalled)
assert.Equal(t, "kbs", mockServer.lastTokenType)
}
func TestGetTokenError(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "aa-error.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{tokenErr: assert.AnError}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
token, err := client.GetToken(ctx, "kbs")
assert.Error(t, err)
assert.Nil(t, token)
assert.Contains(t, err.Error(), "failed to get token from attestation-agent")
}
func TestGetTokenCanceledContext(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "aa-cancel.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{tokenResponse: []byte("token")}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = client.GetToken(ctx, "kbs")
assert.Error(t, err)
}
func TestClientClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "aa-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationAgentServer{}
aa.RegisterAttestationAgentServiceServer(grpcServer, mockServer)
go func() { _ = grpcServer.Serve(listener) }()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
err = client.Close()
assert.NoError(t, err)
}
func TestClientInterface(t *testing.T) {
var _ Client = (*client)(nil)
}
+42
View File
@@ -4,6 +4,7 @@ package attestation
import (
"context"
"fmt"
"time"
attestation_v1 "github.com/ultravioletrs/cocos/internal/proto/attestation/v1"
@@ -14,6 +15,7 @@ import (
type Client interface {
GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error)
GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error)
GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error)
Close() error
}
@@ -57,6 +59,10 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
// Debug: log platform type conversion
fmt.Printf("[ATTESTATION-CLIENT] Platform type conversion: agent=%v (%d) -> proto=%v (%d)\n",
attType, attType, platformType, platformType)
req := &attestation_v1.AttestationRequest{
ReportData: reportData[:],
Nonce: nonce[:],
@@ -71,6 +77,42 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
return resp.EatToken, nil
}
// GetRawEvidence gets raw binary evidence (for KBS) instead of EAT token.
func (c *client) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
var platformType attestation_v1.PlatformType
switch attType {
case attestation.SNP:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP
case attestation.TDX:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_TDX
case attestation.VTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_VTPM
case attestation.SNPvTPM:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM
default:
platformType = attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED
}
fmt.Printf("[ATTESTATION-CLIENT] Getting raw evidence: platform=%v (%d)\n",
attType, platformType)
req := &attestation_v1.AttestationRequest{
ReportData: reportData[:],
Nonce: nonce[:],
PlatformType: platformType,
}
resp, err := c.client.FetchRawEvidence(ctx, req)
if err != nil {
return nil, err
}
return resp.Evidence, nil
}
func (c *client) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
+190
View File
@@ -20,11 +20,13 @@ import (
type mockAttestationServer struct {
attestation_v1.UnimplementedAttestationServiceServer
fetchAttestationCalled bool
fetchRawEvidenceCalled bool
fetchAzureTokenCalled bool
lastReportData []byte
lastNonce []byte
lastPlatformType attestation_v1.PlatformType
attestationErr error
rawEvidenceErr error
azureTokenErr error
}
@@ -43,6 +45,21 @@ func (m *mockAttestationServer) FetchAttestation(ctx context.Context, req *attes
}, nil
}
func (m *mockAttestationServer) FetchRawEvidence(ctx context.Context, req *attestation_v1.AttestationRequest) (*attestation_v1.RawEvidenceResponse, error) {
m.fetchRawEvidenceCalled = true
m.lastReportData = req.ReportData
m.lastNonce = req.Nonce
m.lastPlatformType = req.PlatformType
if m.rawEvidenceErr != nil {
return nil, m.rawEvidenceErr
}
return &attestation_v1.RawEvidenceResponse{
Evidence: []byte("mock-raw-evidence"),
}, nil
}
func (m *mockAttestationServer) FetchAzureToken(ctx context.Context, req *attestation_v1.AzureTokenRequest) (*attestation_v1.AzureTokenResponse, error) {
m.fetchAzureTokenCalled = true
m.lastNonce = req.Nonce
@@ -390,3 +407,176 @@ func TestClientOperationsAfterClose(t *testing.T) {
_, err = client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
assert.Error(t, err)
}
// TestGetRawEvidenceSNP tests getting raw evidence for SNP platform.
func TestGetRawEvidenceSNP(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-snp.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
copy(reportData[:], []byte("test-report-data"))
copy(nonce[:], []byte("test-nonce"))
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.SNP)
require.NoError(t, err)
assert.Equal(t, []byte("mock-raw-evidence"), evidence)
assert.True(t, mockServer.fetchRawEvidenceCalled)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP, mockServer.lastPlatformType)
}
// TestGetRawEvidenceTDX tests getting raw evidence for TDX platform.
func TestGetRawEvidenceTDX(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-tdx.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.TDX)
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetRawEvidenceVTPM tests getting raw evidence for VTPM platform.
func TestGetRawEvidenceVTPM(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-vtpm.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.VTPM)
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_VTPM, mockServer.lastPlatformType)
}
// TestGetRawEvidenceSNPvTPM tests getting raw evidence for SNPvTPM platform.
func TestGetRawEvidenceSNPvTPM(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-snpvtpm.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.SNPvTPM)
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM, mockServer.lastPlatformType)
}
// TestGetRawEvidenceUnspecified tests getting raw evidence with unspecified platform.
func TestGetRawEvidenceUnspecified(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "raw-evidence-unspec.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
evidence, err := client.GetRawEvidence(ctx, reportData, nonce, attestation.PlatformType(999))
require.NoError(t, err)
assert.NotNil(t, evidence)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED, mockServer.lastPlatformType)
}
+44 -6
View File
@@ -40,25 +40,63 @@ func (c *client) Close() error {
}
func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if entry.Timestamp == nil {
entry.Timestamp = timestamppb.Now()
}
// Retry with exponential backoff for concurrent request handling
maxRetries := 3
for attempt := 0; attempt < maxRetries; attempt++ {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err := c.client.SendLog(ctx, entry)
cancel()
if err == nil {
return nil
}
// Don't retry on last attempt
if attempt < maxRetries-1 {
// Exponential backoff: 10ms, 20ms, 40ms
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
time.Sleep(backoff)
}
}
// Return error after all retries exhausted
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_, err := c.client.SendLog(ctx, entry)
return err
}
func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if entry.Timestamp == nil {
entry.Timestamp = timestamppb.Now()
}
// Retry with exponential backoff for concurrent request handling
maxRetries := 3
for attempt := 0; attempt < maxRetries; attempt++ {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err := c.client.SendEvent(ctx, entry)
cancel()
if err == nil {
return nil
}
// Don't retry on last attempt
if attempt < maxRetries-1 {
// Exponential backoff: 10ms, 20ms, 40ms
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
time.Sleep(backoff)
}
}
// Return error after all retries exhausted
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_, err := c.client.SendEvent(ctx, entry)
return err
}
+254
View File
@@ -6,6 +6,7 @@ import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"time"
@@ -330,3 +331,256 @@ func TestClientOperationsAfterClose(t *testing.T) {
err = client.SendLog(ctx, entry)
assert.Error(t, err)
}
// TestClientSendLogRetrySuccess tests SendLog retry behavior.
func TestClientSendLogRetrySuccess(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-retry-success.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.LogEntry{
Level: "INFO",
Message: "retry test",
}
err = client.SendLog(ctx, entry)
require.NoError(t, err)
assert.True(t, mockServer.sendLogCalled)
}
// TestClientSendEventRetrySuccess tests SendEvent retry behavior.
func TestClientSendEventRetrySuccess(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-event-retry.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.EventEntry{
EventType: "test.retry",
}
err = client.SendEvent(ctx, entry)
require.NoError(t, err)
}
// TestClientSendLogRetryWithFailures tests SendLog retry with intermittent failures.
func TestClientSendLogRetryWithFailures(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-retry-failures.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 2, // Fail first 2 attempts
maxFailCount: 2,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.LogEntry{
Level: "INFO",
Message: "retry test",
}
// With retry logic, this should succeed on 3rd attempt
err = client.SendLog(ctx, entry)
require.NoError(t, err)
assert.Equal(t, 3, mockServer.callCount)
}
// TestClientSendEventRetryWithFailures tests SendEvent retry with intermittent failures.
func TestClientSendEventRetryWithFailures(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "event-retry-failures.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 2, // Fail first 2 attempts
maxFailCount: 2,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.EventEntry{
EventType: "test.retry",
}
// With retry logic, this should succeed on 3rd attempt
err = client.SendEvent(ctx, entry)
require.NoError(t, err)
assert.Equal(t, 3, mockServer.eventCallCount)
}
// TestClientSendLogAllRetriesFail tests SendLog when all retries fail.
func TestClientSendLogAllRetriesFail(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-all-fail.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 10, // Fail all attempts
maxFailCount: 10,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.LogEntry{
Level: "ERROR",
Message: "will fail",
}
// Should fail after all retries
err = client.SendLog(ctx, entry)
assert.Error(t, err)
// 3 retries + 1 final attempt = 4 calls
assert.Equal(t, 4, mockServer.callCount)
}
func TestClientSendEventAllRetriesFail(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-event-all-fail.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 10,
maxFailCount: 10,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.EventEntry{
EventType: "TestEvent",
}
// Should fail after all retries
err = client.SendEvent(ctx, entry)
assert.Error(t, err)
// 3 retries + 1 final attempt = 4 calls
assert.Equal(t, 4, mockServer.eventCallCount)
}
// retryMockLogCollectorServer is a mock server that fails a specified number of times.
type retryMockLogCollectorServer struct {
log.UnimplementedLogCollectorServer
failCount int
maxFailCount int
callCount int
eventCallCount int
mu sync.Mutex
}
func (m *retryMockLogCollectorServer) SendLog(ctx context.Context, entry *log.LogEntry) (*emptypb.Empty, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.callCount <= m.maxFailCount {
return nil, assert.AnError
}
return &emptypb.Empty{}, nil
}
func (m *retryMockLogCollectorServer) SendEvent(ctx context.Context, entry *log.EventEntry) (*emptypb.Empty, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.eventCallCount++
if m.eventCallCount <= m.maxFailCount {
return nil, assert.AnError
}
return &emptypb.Empty{}, nil
}
+218
View File
@@ -0,0 +1,218 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/absmach/supermq/pkg/errors"
"golang.org/x/crypto/hkdf"
)
var (
// ErrDecryptionFailed indicates a decryption operation failed.
ErrDecryptionFailed = errors.New("decryption failed")
// ErrInvalidKey indicates the provided key is invalid.
ErrInvalidKey = errors.New("invalid decryption key")
// ErrInvalidCiphertext indicates the ciphertext is invalid or corrupted.
ErrInvalidCiphertext = errors.New("invalid ciphertext")
// ErrInvalidFormat indicates the encrypted resource format is invalid.
ErrInvalidFormat = errors.New("invalid encrypted resource format")
)
// EncryptedResource represents an encrypted resource from KBS.
// This matches the format used by Confidential Containers KBS.
type EncryptedResource struct {
// Ciphertext is the encrypted data.
Ciphertext []byte `json:"ciphertext"`
// EncryptedKey is the wrapped encryption key.
EncryptedKey []byte `json:"encrypted_key"`
// IV is the initialization vector for AES-GCM.
IV []byte `json:"iv"`
// Tag is the authentication tag for AES-GCM.
Tag []byte `json:"tag"`
// AAD is the additional authenticated data.
AAD []byte `json:"aad,omitempty"`
// EPK is the ephemeral public key for ECDH key derivation.
EPK *EphemeralPublicKey `json:"epk,omitempty"`
}
// EphemeralPublicKey represents an ephemeral EC P-256 public key.
type EphemeralPublicKey struct {
// Curve is the elliptic curve (should be "P-256").
Curve string `json:"crv"`
// X is the X coordinate of the public key.
X string `json:"x"`
// Y is the Y coordinate of the public key.
Y string `json:"y"`
}
// DecryptAESGCM decrypts data using AES-GCM with the provided key.
// This is used when the decryption key is provided directly (not wrapped).
func DecryptAESGCM(ciphertext, key, iv, tag, aad []byte) ([]byte, error) {
if len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, errors.Wrap(ErrInvalidKey, errors.New("key must be 16, 24, or 32 bytes"))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
// Combine ciphertext and tag for GCM
combined := append(ciphertext, tag...)
plaintext, err := aesgcm.Open(nil, iv, combined, aad)
if err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
return plaintext, nil
}
// DecryptWithWrappedKey decrypts data using a wrapped key and ECDH key derivation.
// This matches the KBS encryption format with ephemeral key exchange.
func DecryptWithWrappedKey(encryptedResource EncryptedResource, privateKey *ecdh.PrivateKey) ([]byte, error) {
if encryptedResource.EPK == nil {
return nil, errors.Wrap(ErrInvalidFormat, errors.New("ephemeral public key is required"))
}
// Decode ephemeral public key coordinates
xBytes, err := base64.RawURLEncoding.DecodeString(encryptedResource.EPK.X)
if err != nil {
return nil, errors.Wrap(ErrInvalidFormat, err)
}
yBytes, err := base64.RawURLEncoding.DecodeString(encryptedResource.EPK.Y)
if err != nil {
return nil, errors.Wrap(ErrInvalidFormat, err)
}
// Reconstruct ephemeral public key (uncompressed format: 0x04 || X || Y)
epkBytes := make([]byte, 1+len(xBytes)+len(yBytes))
epkBytes[0] = 0x04
copy(epkBytes[1:], xBytes)
copy(epkBytes[1+len(xBytes):], yBytes)
curve := ecdh.P256()
epk, err := curve.NewPublicKey(epkBytes)
if err != nil {
return nil, errors.Wrap(ErrInvalidFormat, err)
}
// Perform ECDH to derive shared secret
sharedSecret, err := privateKey.ECDH(epk)
if err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
// Derive KEK (Key Encryption Key) using HKDF
kek := make([]byte, 32)
kdf := hkdf.New(sha256.New, sharedSecret, nil, nil)
if _, err := kdf.Read(kek); err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
// Unwrap the content encryption key (CEK)
cek, err := unwrapKey(encryptedResource.EncryptedKey, kek)
if err != nil {
return nil, err
}
// Decrypt the actual content using the CEK
plaintext, err := DecryptAESGCM(
encryptedResource.Ciphertext,
cek,
encryptedResource.IV,
encryptedResource.Tag,
encryptedResource.AAD,
)
if err != nil {
return nil, err
}
// Zero out sensitive key material
zeroBytes(kek)
zeroBytes(cek)
zeroBytes(sharedSecret)
return plaintext, nil
}
// unwrapKey unwraps an encrypted key using AES Key Wrap (RFC 3394).
func unwrapKey(wrappedKey, kek []byte) ([]byte, error) {
if len(wrappedKey)%8 != 0 || len(wrappedKey) < 24 {
return nil, errors.Wrap(ErrInvalidKey, errors.New("wrapped key length must be a multiple of 8 and at least 24 bytes"))
}
block, err := aes.NewCipher(kek)
if err != nil {
return nil, errors.Wrap(ErrDecryptionFailed, err)
}
n := len(wrappedKey)/8 - 1
r := make([][]byte, n+1)
r[0] = wrappedKey[:8]
for i := 1; i <= n; i++ {
r[i] = wrappedKey[i*8 : (i+1)*8]
}
a := r[0]
for j := 5; j >= 0; j-- {
for i := n; i >= 1; i-- {
t := uint64(n*j + i)
b := make([]byte, 16)
for k := 0; k < 8; k++ {
b[k] = a[k] ^ byte(t>>(56-8*k))
}
copy(b[8:], r[i])
block.Decrypt(b, b)
a = b[:8]
r[i] = b[8:]
}
}
// Check integrity value
expectedIV := []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6}
for i := 0; i < 8; i++ {
if a[i] != expectedIV[i] {
return nil, errors.Wrap(ErrDecryptionFailed, errors.New("key unwrap integrity check failed"))
}
}
// Concatenate unwrapped key
unwrapped := make([]byte, 0, n*8)
for i := 1; i <= n; i++ {
unwrapped = append(unwrapped, r[i]...)
}
return unwrapped, nil
}
// ParseEncryptedResource parses a JSON-encoded encrypted resource.
func ParseEncryptedResource(data []byte) (*EncryptedResource, error) {
var resource EncryptedResource
if err := json.Unmarshal(data, &resource); err != nil {
return nil, errors.Wrap(ErrInvalidFormat, err)
}
return &resource, nil
}
// zeroBytes securely zeros out a byte slice.
func zeroBytes(b []byte) {
for i := range b {
b[i] = 0
}
}
+712
View File
@@ -0,0 +1,712 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/hkdf"
)
// testAESKeyWrap implements RFC 3394 AES Key Wrap for use in test setup.
func testAESKeyWrap(kek, key []byte) ([]byte, error) {
block, err := aes.NewCipher(kek)
if err != nil {
return nil, err
}
n := len(key) / 8
a := []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6}
r := make([][]byte, n+1)
for i := 1; i <= n; i++ {
r[i] = make([]byte, 8)
copy(r[i], key[(i-1)*8:i*8])
}
for j := 0; j <= 5; j++ {
for i := 1; i <= n; i++ {
t := uint64(n*j + i)
b := make([]byte, 16)
copy(b[:8], a)
copy(b[8:], r[i])
block.Encrypt(b, b)
for k := 0; k < 8; k++ {
a[k] = b[k] ^ byte(t>>(56-8*k))
}
r[i] = make([]byte, 8)
copy(r[i], b[8:])
}
}
result := make([]byte, (n+1)*8)
copy(result[:8], a)
for i := 1; i <= n; i++ {
copy(result[i*8:(i+1)*8], r[i])
}
return result, nil
}
func TestDecryptAESGCM(t *testing.T) {
// Generate a valid key
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
// Generate valid plaintext
plaintext := []byte("test plaintext data")
// Create cipher and encrypt
block, err := aes.NewCipher(key)
require.NoError(t, err)
aesgcm, err := cipher.NewGCM(block)
require.NoError(t, err)
iv := make([]byte, aesgcm.NonceSize())
_, err = rand.Read(iv)
require.NoError(t, err)
aad := []byte("additional data")
ciphertext := aesgcm.Seal(nil, iv, plaintext, aad)
// Split ciphertext and tag
tag := ciphertext[len(ciphertext)-aesgcm.Overhead():]
ciphertextOnly := ciphertext[:len(ciphertext)-aesgcm.Overhead()]
tests := []struct {
name string
ciphertext []byte
key []byte
iv []byte
tag []byte
aad []byte
wantErr bool
errContain string
}{
{
name: "valid decryption",
ciphertext: ciphertextOnly,
key: key,
iv: iv,
tag: tag,
aad: aad,
wantErr: false,
},
{
name: "invalid key length",
ciphertext: ciphertextOnly,
key: []byte("short"),
iv: iv,
tag: tag,
aad: aad,
wantErr: true,
errContain: "key must be 16, 24, or 32 bytes",
},
{
name: "wrong key",
ciphertext: ciphertextOnly,
key: make([]byte, 32),
iv: iv,
tag: tag,
aad: aad,
wantErr: true,
errContain: "decryption failed",
},
{
name: "corrupted tag",
ciphertext: ciphertextOnly,
key: key,
iv: iv,
tag: make([]byte, len(tag)),
aad: aad,
wantErr: true,
errContain: "decryption failed",
},
{
name: "wrong aad",
ciphertext: ciphertextOnly,
key: key,
iv: iv,
tag: tag,
aad: []byte("wrong aad"),
wantErr: true,
errContain: "decryption failed",
},
{
name: "16 byte key",
ciphertext: nil,
key: make([]byte, 16),
iv: nil,
tag: nil,
aad: nil,
wantErr: false,
},
{
name: "24 byte key",
ciphertext: nil,
key: make([]byte, 24),
iv: nil,
tag: nil,
aad: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// For tests with nil ciphertext, create new cipher with specified key
if tt.ciphertext == nil && !tt.wantErr {
_, err := rand.Read(tt.key)
require.NoError(t, err)
block, err := aes.NewCipher(tt.key)
require.NoError(t, err)
aesgcm, err := cipher.NewGCM(block)
require.NoError(t, err)
tt.iv = make([]byte, aesgcm.NonceSize())
_, err = rand.Read(tt.iv)
require.NoError(t, err)
tt.aad = []byte("test aad")
ciphertext := aesgcm.Seal(nil, tt.iv, plaintext, tt.aad)
tt.tag = ciphertext[len(ciphertext)-aesgcm.Overhead():]
tt.ciphertext = ciphertext[:len(ciphertext)-aesgcm.Overhead()]
}
got, err := DecryptAESGCM(tt.ciphertext, tt.key, tt.iv, tt.tag, tt.aad)
if tt.wantErr {
assert.Error(t, err)
if tt.errContain != "" {
assert.Contains(t, err.Error(), tt.errContain)
}
} else {
require.NoError(t, err)
assert.Equal(t, plaintext, got)
}
})
}
}
func TestParseEncryptedResource(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{
name: "valid encrypted resource",
data: func() []byte {
resource := EncryptedResource{
Ciphertext: []byte("encrypted data"),
EncryptedKey: []byte("wrapped key"),
IV: []byte("initialization vector"),
Tag: []byte("auth tag"),
AAD: []byte("additional data"),
}
data, err := json.Marshal(resource)
if err != nil {
panic(err)
}
return data
}(),
wantErr: false,
},
{
name: "valid encrypted resource with EPK",
data: func() []byte {
resource := EncryptedResource{
Ciphertext: []byte("encrypted data"),
EncryptedKey: []byte("wrapped key"),
IV: []byte("initialization vector"),
Tag: []byte("auth tag"),
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: "AAAA",
Y: "BBBB",
},
}
data, err := json.Marshal(resource)
if err != nil {
panic(err)
}
return data
}(),
wantErr: false,
},
{
name: "invalid JSON",
data: []byte("not valid json"),
wantErr: true,
},
{
name: "empty JSON",
data: []byte("{}"),
wantErr: false,
},
{
name: "empty data",
data: []byte{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseEncryptedResource(tt.data)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
} else {
require.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestZeroBytes(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{
name: "zero empty slice",
input: []byte{},
},
{
name: "zero small slice",
input: []byte{1, 2, 3, 4, 5},
},
{
name: "zero large slice",
input: make([]byte, 1024),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Fill with non-zero values
for i := range tt.input {
tt.input[i] = byte(i + 1)
}
zeroBytes(tt.input)
// Verify all bytes are zero
for i, b := range tt.input {
assert.Equal(t, byte(0), b, "byte at index %d should be 0", i)
}
})
}
}
func TestDecryptWithWrappedKey(t *testing.T) {
tests := []struct {
name string
encryptedResource EncryptedResource
privateKey *ecdh.PrivateKey
wantErr bool
errContain string
}{
{
name: "missing ephemeral public key",
encryptedResource: EncryptedResource{
Ciphertext: []byte("test"),
EncryptedKey: []byte("key"),
IV: []byte("iv"),
Tag: []byte("tag"),
EPK: nil,
},
privateKey: nil,
wantErr: true,
errContain: "ephemeral public key is required",
},
{
name: "invalid X coordinate encoding",
encryptedResource: EncryptedResource{
Ciphertext: []byte("test"),
EncryptedKey: []byte("key"),
IV: []byte("iv"),
Tag: []byte("tag"),
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: "!!!invalid base64!!!",
Y: "AAAA",
},
},
privateKey: nil,
wantErr: true,
errContain: "invalid encrypted resource format",
},
{
name: "invalid Y coordinate encoding",
encryptedResource: EncryptedResource{
Ciphertext: []byte("test"),
EncryptedKey: []byte("key"),
IV: []byte("iv"),
Tag: []byte("tag"),
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
Y: "!!!invalid base64!!!",
},
},
privateKey: nil,
wantErr: true,
errContain: "invalid encrypted resource format",
},
{
name: "invalid public key bytes",
encryptedResource: EncryptedResource{
Ciphertext: []byte("test"),
EncryptedKey: []byte("key"),
IV: []byte("iv"),
Tag: []byte("tag"),
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: base64.RawURLEncoding.EncodeToString([]byte("short")),
Y: base64.RawURLEncoding.EncodeToString([]byte("short")),
},
},
privateKey: func() *ecdh.PrivateKey {
key, _ := ecdh.P256().GenerateKey(rand.Reader)
return key
}(),
wantErr: true,
errContain: "invalid encrypted resource format",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := DecryptWithWrappedKey(tt.encryptedResource, tt.privateKey)
if tt.wantErr {
assert.Error(t, err)
if tt.errContain != "" {
assert.Contains(t, err.Error(), tt.errContain)
}
} else {
require.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestUnwrapKey(t *testing.T) {
tests := []struct {
name string
wrappedKey []byte
kek []byte
wantErr bool
errContain string
}{
{
name: "wrapped key too short",
wrappedKey: []byte("short"),
kek: make([]byte, 32),
wantErr: true,
errContain: "wrapped key length must be a multiple of 8 and at least 24 bytes",
},
{
name: "wrapped key not multiple of 8",
wrappedKey: make([]byte, 25),
kek: make([]byte, 32),
wantErr: true,
errContain: "wrapped key length must be a multiple of 8 and at least 24 bytes",
},
{
name: "invalid kek length",
wrappedKey: make([]byte, 24),
kek: []byte("short"),
wantErr: true,
errContain: "decryption failed",
},
{
name: "integrity check failure",
wrappedKey: make([]byte, 24),
kek: make([]byte, 32),
wantErr: true,
errContain: "key unwrap integrity check failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := unwrapKey(tt.wrappedKey, tt.kek)
if tt.wantErr {
assert.Error(t, err)
if tt.errContain != "" {
assert.Contains(t, err.Error(), tt.errContain)
}
} else {
require.NoError(t, err)
assert.NotNil(t, got)
}
})
}
}
func TestEncryptedResourceStructure(t *testing.T) {
t.Run("EphemeralPublicKey JSON serialization", func(t *testing.T) {
epk := EphemeralPublicKey{
Curve: "P-256",
X: "test_x",
Y: "test_y",
}
data, err := json.Marshal(epk)
require.NoError(t, err)
var decoded EphemeralPublicKey
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, epk.Curve, decoded.Curve)
assert.Equal(t, epk.X, decoded.X)
assert.Equal(t, epk.Y, decoded.Y)
})
t.Run("EncryptedResource JSON serialization", func(t *testing.T) {
resource := EncryptedResource{
Ciphertext: []byte("ciphertext"),
EncryptedKey: []byte("encrypted_key"),
IV: []byte("iv"),
Tag: []byte("tag"),
AAD: []byte("aad"),
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: "x_coord",
Y: "y_coord",
},
}
data, err := json.Marshal(resource)
require.NoError(t, err)
var decoded EncryptedResource
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, resource.Ciphertext, decoded.Ciphertext)
assert.Equal(t, resource.EncryptedKey, decoded.EncryptedKey)
assert.Equal(t, resource.IV, decoded.IV)
assert.Equal(t, resource.Tag, decoded.Tag)
assert.Equal(t, resource.AAD, decoded.AAD)
assert.NotNil(t, decoded.EPK)
assert.Equal(t, resource.EPK.Curve, decoded.EPK.Curve)
})
}
func TestDecryptWithWrappedKeyFullRoundTrip(t *testing.T) {
t.Run("full ECDH + key wrap + AES-GCM round trip", func(t *testing.T) {
// Generate recipient private key (who will decrypt)
recipientKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
// Generate ephemeral key pair (used to encrypt)
ephemeralKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
// Compute shared secret: ephemeral_private ECDH recipient_public
sharedSecret, err := ephemeralKey.ECDH(recipientKey.PublicKey())
require.NoError(t, err)
// Derive KEK using HKDF (same as in DecryptWithWrappedKey)
kek := make([]byte, 32)
kdf := hkdf.New(sha256.New, sharedSecret, nil, nil)
_, err = kdf.Read(kek)
require.NoError(t, err)
// Generate random CEK (32 bytes)
cek := make([]byte, 32)
_, err = rand.Read(cek)
require.NoError(t, err)
// Wrap CEK using AES Key Wrap (RFC 3394)
wrappedKey, err := testAESKeyWrap(kek, cek)
require.NoError(t, err)
// Encrypt plaintext with AES-GCM using CEK
plaintext := []byte("hello world secret message for testing")
blk, err := aes.NewCipher(cek)
require.NoError(t, err)
aesgcm, err := cipher.NewGCM(blk)
require.NoError(t, err)
iv := make([]byte, aesgcm.NonceSize())
_, err = rand.Read(iv)
require.NoError(t, err)
// Go's Seal returns ciphertext || tag
combined := aesgcm.Seal(nil, iv, plaintext, nil)
ciphertext := combined[:len(combined)-aesgcm.Overhead()]
tag := combined[len(combined)-aesgcm.Overhead():]
// Get ephemeral public key coordinates (uncompressed: 0x04 || X(32) || Y(32))
epkPubBytes := ephemeralKey.PublicKey().Bytes()
xBytes := epkPubBytes[1:33]
yBytes := epkPubBytes[33:65]
resource := EncryptedResource{
Ciphertext: ciphertext,
EncryptedKey: wrappedKey,
IV: iv,
Tag: tag,
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
},
}
decrypted, err := DecryptWithWrappedKey(resource, recipientKey)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
})
t.Run("full round trip with AAD", func(t *testing.T) {
recipientKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
ephemeralKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
sharedSecret, err := ephemeralKey.ECDH(recipientKey.PublicKey())
require.NoError(t, err)
kek := make([]byte, 32)
kdf := hkdf.New(sha256.New, sharedSecret, nil, nil)
_, err = kdf.Read(kek)
require.NoError(t, err)
cek := make([]byte, 16) // 16-byte CEK (AES-128)
_, err = rand.Read(cek)
require.NoError(t, err)
wrappedKey, err := testAESKeyWrap(kek, cek)
require.NoError(t, err)
plaintext := []byte("confidential data with AAD")
aad := []byte("additional authenticated data")
blk, err := aes.NewCipher(cek)
require.NoError(t, err)
aesgcm, err := cipher.NewGCM(blk)
require.NoError(t, err)
iv := make([]byte, aesgcm.NonceSize())
_, err = rand.Read(iv)
require.NoError(t, err)
combined := aesgcm.Seal(nil, iv, plaintext, aad)
ciphertext := combined[:len(combined)-aesgcm.Overhead()]
tag := combined[len(combined)-aesgcm.Overhead():]
epkPubBytes := ephemeralKey.PublicKey().Bytes()
xBytes := epkPubBytes[1:33]
yBytes := epkPubBytes[33:65]
resource := EncryptedResource{
Ciphertext: ciphertext,
EncryptedKey: wrappedKey,
IV: iv,
Tag: tag,
AAD: aad,
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
},
}
decrypted, err := DecryptWithWrappedKey(resource, recipientKey)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
})
t.Run("wrong private key fails decryption", func(t *testing.T) {
recipientKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
wrongKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
ephemeralKey, err := ecdh.P256().GenerateKey(rand.Reader)
require.NoError(t, err)
sharedSecret, err := ephemeralKey.ECDH(recipientKey.PublicKey())
require.NoError(t, err)
kek := make([]byte, 32)
kdf := hkdf.New(sha256.New, sharedSecret, nil, nil)
_, err = kdf.Read(kek)
require.NoError(t, err)
cek := make([]byte, 32)
_, err = rand.Read(cek)
require.NoError(t, err)
wrappedKey, err := testAESKeyWrap(kek, cek)
require.NoError(t, err)
plaintext := []byte("secret")
blk, err := aes.NewCipher(cek)
require.NoError(t, err)
aesgcm, err := cipher.NewGCM(blk)
require.NoError(t, err)
iv := make([]byte, aesgcm.NonceSize())
_, err = rand.Read(iv)
require.NoError(t, err)
combined := aesgcm.Seal(nil, iv, plaintext, nil)
ciphertext := combined[:len(combined)-aesgcm.Overhead()]
tag := combined[len(combined)-aesgcm.Overhead():]
epkPubBytes := ephemeralKey.PublicKey().Bytes()
xBytes := epkPubBytes[1:33]
yBytes := epkPubBytes[33:65]
resource := EncryptedResource{
Ciphertext: ciphertext,
EncryptedKey: wrappedKey,
IV: iv,
Tag: tag,
EPK: &EphemeralPublicKey{
Curve: "P-256",
X: base64.RawURLEncoding.EncodeToString(xBytes),
Y: base64.RawURLEncoding.EncodeToString(yBytes),
},
}
// Using wrong key should fail
_, err = DecryptWithWrappedKey(resource, wrongKey)
assert.Error(t, err)
})
}
func TestErrorTypes(t *testing.T) {
t.Run("error constants are defined", func(t *testing.T) {
assert.NotNil(t, ErrDecryptionFailed)
assert.NotNil(t, ErrInvalidKey)
assert.NotNil(t, ErrInvalidCiphertext)
assert.NotNil(t, ErrInvalidFormat)
assert.Equal(t, "decryption failed", ErrDecryptionFailed.Error())
assert.Equal(t, "invalid decryption key", ErrInvalidKey.Error())
assert.Equal(t, "invalid ciphertext", ErrInvalidCiphertext.Error())
assert.Equal(t, "invalid encrypted resource format", ErrInvalidFormat.Error())
})
}
+342
View File
@@ -0,0 +1,342 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package oci
import (
"archive/tar"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
)
// OCILayout represents the OCI image layout.
type OCILayout struct {
ImageLayoutVersion string `json:"imageLayoutVersion"`
}
// OCIIndex represents the OCI index.json.
type OCIIndex struct {
SchemaVersion int `json:"schemaVersion"`
Manifests []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
} `json:"manifests"`
}
// ExtractAlgorithm extracts the algorithm file from an OCI image directory.
func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath string) (string, error) {
// Read index.json to find manifest
indexPath := filepath.Join(ociDir, "index.json")
indexData, err := os.ReadFile(indexPath)
if err != nil {
return "", fmt.Errorf("failed to read index.json: %w", err)
}
var index OCIIndex
if err := json.Unmarshal(indexData, &index); err != nil {
return "", fmt.Errorf("failed to parse index.json: %w", err)
}
if len(index.Manifests) == 0 {
return "", fmt.Errorf("no manifests found in index.json")
}
// Get the first manifest digest
manifestDigest := index.Manifests[0].Digest
manifestPath := filepath.Join(ociDir, "blobs", strings.Replace(manifestDigest, ":", "/", 1))
// Read manifest to find layers
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
return "", fmt.Errorf("failed to read manifest: %w", err)
}
var manifest struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return "", fmt.Errorf("failed to parse manifest: %w", err)
}
// Extract layers to find algorithm files
logger.Debug("found layers in manifest", "count", len(manifest.Layers))
var allSeenFiles []string
// Iterate layers in reverse order to find user code first (usually in top layers)
for i := len(manifest.Layers) - 1; i >= 0; i-- {
layer := manifest.Layers[i]
layerPath := filepath.Join(ociDir, "blobs", strings.Replace(layer.Digest, ":", "/", 1))
// Try to extract and find algorithm file
algoPath, seenFiles, err := extractLayerAndFindAlgorithm(logger, layerPath, destPath)
if len(seenFiles) > 0 {
allSeenFiles = append(allSeenFiles, seenFiles...)
}
if err != nil {
logger.Warn(fmt.Sprintf("error extracting layer %s: %v", layer.Digest, err))
continue
}
if algoPath != "" {
return algoPath, nil
}
}
return "", fmt.Errorf("no algorithm file found in OCI image layers (seen: %v)", allSeenFiles)
}
// extractLayerAndFindAlgorithm extracts a layer and searches for algorithm files.
func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath string) (string, []string, error) {
// Open layer file
layerFile, err := os.Open(layerPath)
if err != nil {
return "", nil, fmt.Errorf("failed to open layer: %w", err)
}
defer layerFile.Close()
// Decompress gzip
gzReader, err := gzip.NewReader(layerFile)
if err != nil {
return "", nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
// Read tar archive
tarReader := tar.NewReader(gzReader)
var algorithmPath string
var seenFiles []string
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return "", seenFiles, fmt.Errorf("failed to read tar header: %w", err)
}
logger.Debug("inspecting file in layer", "name", header.Name, "type", header.Typeflag)
// Skip directories
if header.Typeflag == tar.TypeDir {
continue
}
seenFiles = append(seenFiles, header.Name)
// Check if this is an algorithm file or requirements.txt
isAlgo := isAlgorithmFile(header.Name)
isReq := filepath.Base(header.Name) == "requirements.txt"
if isAlgo || isReq {
// Extract to destination, preserving directory structure
// Clean the name to prevent path traversal
cleanName := filepath.Clean(header.Name)
if strings.HasPrefix(cleanName, "..") || strings.HasPrefix(cleanName, "/") {
continue
}
targetPath := filepath.Join(destPath, cleanName)
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return "", seenFiles, fmt.Errorf("failed to create dir: %w", err)
}
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", seenFiles, fmt.Errorf("failed to create file: %w", err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return "", seenFiles, fmt.Errorf("failed to write file: %w", err)
}
outFile.Close()
if isAlgo {
algorithmPath = targetPath
}
// Continue scanning to extract other files (like requirements.txt)
}
}
return algorithmPath, seenFiles, nil
}
// isAlgorithmFile checks if a file is likely an algorithm file.
func isAlgorithmFile(filename string) bool {
// Common algorithm file extensions
algorithmExts := []string{".py", ".wasm", ".wat", ".js", ".sh"}
// Common algorithm file names
algorithmNames := []string{"algorithm", "main", "run", "execute"}
base := filepath.Base(filename)
baseLower := strings.ToLower(base)
// Check extensions
for _, ext := range algorithmExts {
if strings.HasSuffix(baseLower, ext) {
return true
}
}
// Check common names
for _, name := range algorithmNames {
if strings.Contains(baseLower, name) {
return true
}
}
return false
}
// ExtractDataset extracts dataset files from an OCI image directory.
func ExtractDataset(ociDir, destPath string) ([]string, error) {
// Similar to ExtractAlgorithm but extracts all data files
// Read index.json to find manifest
indexPath := filepath.Join(ociDir, "index.json")
indexData, err := os.ReadFile(indexPath)
if err != nil {
return nil, fmt.Errorf("failed to read index.json: %w", err)
}
var index OCIIndex
if err := json.Unmarshal(indexData, &index); err != nil {
return nil, fmt.Errorf("failed to parse index.json: %w", err)
}
if len(index.Manifests) == 0 {
return nil, fmt.Errorf("no manifests found in index.json")
}
// Get the first manifest digest
manifestDigest := index.Manifests[0].Digest
manifestPath := filepath.Join(ociDir, "blobs", strings.Replace(manifestDigest, ":", "/", 1))
// Read manifest to find layers
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("failed to read manifest: %w", err)
}
var manifest struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return nil, fmt.Errorf("failed to parse manifest: %w", err)
}
var datasetFiles []string
// Extract all layers and collect dataset files
// Iterate layers in reverse order to find user data first (usually in top layers)
for i := len(manifest.Layers) - 1; i >= 0; i-- {
layer := manifest.Layers[i]
layerPath := filepath.Join(ociDir, "blobs", strings.Replace(layer.Digest, ":", "/", 1))
files, err := extractLayerDataFiles(layerPath, destPath)
if err != nil {
slog.Warn("error extracting layer", "digest", layer.Digest, "error", err)
continue
}
datasetFiles = append(datasetFiles, files...)
}
if len(datasetFiles) == 0 {
return nil, fmt.Errorf("no dataset files found in OCI image layers")
}
return datasetFiles, nil
}
// extractLayerDataFiles extracts data files from a layer.
func extractLayerDataFiles(layerPath, destPath string) ([]string, error) {
layerFile, err := os.Open(layerPath)
if err != nil {
return nil, err
}
defer layerFile.Close()
gzReader, err := gzip.NewReader(layerFile)
if err != nil {
return nil, err
}
defer gzReader.Close()
tarReader := tar.NewReader(gzReader)
var extractedFiles []string
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
if header.Typeflag == tar.TypeDir {
continue
}
// Check if this is a data file
if isDataFile(header.Name) {
// Extract to destination, preserving directory structure
cleanName := filepath.Clean(header.Name)
if strings.HasPrefix(cleanName, "..") || strings.HasPrefix(cleanName, "/") {
continue
}
targetPath := filepath.Join(destPath, cleanName)
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return nil, err
}
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return nil, err
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return nil, err
}
outFile.Close()
extractedFiles = append(extractedFiles, targetPath)
}
}
return extractedFiles, nil
}
// isDataFile checks if a file is likely a dataset file.
func isDataFile(filename string) bool {
dataExts := []string{".csv", ".json", ".txt", ".parquet", ".arrow", ".dat"}
baseLower := strings.ToLower(filepath.Base(filename))
for _, ext := range dataExts {
if strings.HasSuffix(baseLower, ext) {
return true
}
}
return false
}
+920
View File
@@ -0,0 +1,920 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package oci
import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"encoding/json"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testPythonScript = "print('hello')"
func TestIsAlgorithmFile(t *testing.T) {
tests := []struct {
name string
filename string
want bool
}{
{"Python file", "algorithm.py", true},
{"WASM file", "module.wasm", true},
{"WAT file", "module.wat", true},
{"JavaScript file", "script.js", true},
{"Shell script", "run.sh", true},
{"Main python file", "main.py", true},
{"Execute file", "execute.py", true},
{"Algorithm name in path", "src/algorithm_v2.py", true},
{"Random python file", "helper.py", true},
{"CSV data file", "data.csv", false},
{"JSON config file", "config.json", false},
{"Text file", "readme.txt", false},
{"Binary file", "data.bin", false},
{"Uppercase extension", "MAIN.PY", true},
{"Mixed case", "Algorithm.Py", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isAlgorithmFile(tt.filename)
assert.Equal(t, tt.want, got)
})
}
}
func TestIsDataFile(t *testing.T) {
tests := []struct {
name string
filename string
want bool
}{
{"CSV file", "data.csv", true},
{"JSON file", "config.json", true},
{"Text file", "readme.txt", true},
{"Parquet file", "data.parquet", true},
{"Arrow file", "data.arrow", true},
{"DAT file", "data.dat", true},
{"Python file", "script.py", false},
{"WASM file", "module.wasm", false},
{"Binary file", "data.bin", false},
{"Uppercase CSV", "DATA.CSV", true},
{"Nested path", "data/input/dataset.csv", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isDataFile(tt.filename)
assert.Equal(t, tt.want, got)
})
}
}
func TestExtractAlgorithm(t *testing.T) {
logger := slog.Default()
t.Run("missing index.json", func(t *testing.T) {
tempDir := t.TempDir()
_, err := ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read index.json")
})
t.Run("invalid index.json", func(t *testing.T) {
tempDir := t.TempDir()
err := os.WriteFile(filepath.Join(tempDir, "index.json"), []byte("not json"), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse index.json")
})
t.Run("empty manifests", func(t *testing.T) {
tempDir := t.TempDir()
index := OCIIndex{SchemaVersion: 2}
data, _ := json.Marshal(index)
err := os.WriteFile(filepath.Join(tempDir, "index.json"), data, 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "no manifests found")
})
t.Run("successful extraction", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "algorithm.py", testPythonScript)
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
require.NoError(t, err)
assert.NotEmpty(t, algoPath)
assert.Contains(t, algoPath, "algorithm.py")
})
}
func TestExtractDataset(t *testing.T) {
t.Run("missing index.json", func(t *testing.T) {
tempDir := t.TempDir()
_, err := ExtractDataset(tempDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read index.json")
})
t.Run("successful extraction", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "data.csv", "col1,col2\n1,2")
files, err := ExtractDataset(ociDir, destDir)
require.NoError(t, err)
assert.NotEmpty(t, files)
})
}
func TestExtractDatasetWithPathTraversal(t *testing.T) {
t.Run("path traversal skipped, valid file extracted", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Path traversal entry (should be skipped)
maliciousHdr := &tar.Header{
Name: "../../../tmp/evil.csv",
Mode: 0o644,
Size: int64(len("evil")),
}
require.NoError(t, tw.WriteHeader(maliciousHdr))
_, err = tw.Write([]byte("evil"))
require.NoError(t, err)
// Valid CSV file
csvContent := "col1,col2\n1,2"
csvHdr := &tar.Header{
Name: "data.csv",
Mode: 0o644,
Size: int64(len(csvContent)),
}
require.NoError(t, tw.WriteHeader(csvHdr))
_, err = tw.Write([]byte(csvContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
files, err := ExtractDataset(ociDir, destDir)
require.NoError(t, err)
assert.Len(t, files, 1)
assert.Contains(t, files[0], "data.csv")
// Verify malicious file was NOT created outside destDir
_, err = os.Stat("/tmp/evil.csv")
assert.True(t, os.IsNotExist(err))
})
}
func TestExtractDatasetInvalidManifest(t *testing.T) {
t.Run("invalid manifest JSON", func(t *testing.T) {
ociDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), []byte("not json"), 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: 8}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractDataset(ociDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse manifest")
})
}
func TestExtractDatasetWithDirectory(t *testing.T) {
t.Run("layer with directory entries for dataset", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Directory entry
dirHdr := &tar.Header{
Name: "data/",
Mode: 0o755,
Typeflag: tar.TypeDir,
}
require.NoError(t, tw.WriteHeader(dirHdr))
// CSV inside directory
csvContent := "a,b\n1,2"
csvHdr := &tar.Header{
Name: "data/dataset.csv",
Mode: 0o644,
Size: int64(len(csvContent)),
}
require.NoError(t, tw.WriteHeader(csvHdr))
_, err = tw.Write([]byte(csvContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
files, err := ExtractDataset(ociDir, destDir)
require.NoError(t, err)
require.Len(t, files, 1)
assert.Contains(t, files[0], "dataset.csv")
})
}
func TestExtractDatasetMissingManifest(t *testing.T) {
t.Run("manifest file not found", func(t *testing.T) {
ociDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:nonexistent", Size: 0}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractDataset(ociDir, t.TempDir())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read manifest")
})
}
func TestOCILayoutStructure(t *testing.T) {
t.Run("OCILayout JSON serialization", func(t *testing.T) {
layout := OCILayout{ImageLayoutVersion: "1.0.0"}
data, err := json.Marshal(layout)
require.NoError(t, err)
var decoded OCILayout
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, layout.ImageLayoutVersion, decoded.ImageLayoutVersion)
})
}
func setupTestOCIImage(t *testing.T, filename, content string) (ociDir, destDir string) {
t.Helper()
ociDir = t.TempDir()
destDir = t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
hdr := &tar.Header{
Name: filename,
Mode: 0o644,
Size: int64(len(content)),
}
require.NoError(t, tw.WriteHeader(hdr))
_, err = tw.Write([]byte(content))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, err := json.Marshal(manifest)
require.NoError(t, err)
manifestPath := filepath.Join(blobsDir, "manifest123")
require.NoError(t, os.WriteFile(manifestPath, manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{
MediaType: "application/vnd.oci.image.manifest.v1+json",
Digest: "sha256:manifest123",
Size: len(manifestData),
}},
}
indexData, err := json.Marshal(index)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
return ociDir, destDir
}
func TestExtractAlgorithmWithRequirements(t *testing.T) {
logger := slog.Default()
t.Run("extract algorithm with requirements.txt", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Add algorithm file
algoContent := testPythonScript
algoHdr := &tar.Header{
Name: "main.py",
Mode: 0o644,
Size: int64(len(algoContent)),
}
require.NoError(t, tw.WriteHeader(algoHdr))
_, err = tw.Write([]byte(algoContent))
require.NoError(t, err)
// Add requirements.txt
reqContent := "numpy==1.21.0\npandas==1.3.0"
reqHdr := &tar.Header{
Name: "requirements.txt",
Mode: 0o644,
Size: int64(len(reqContent)),
}
require.NoError(t, tw.WriteHeader(reqHdr))
_, err = tw.Write([]byte(reqContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
// Create manifest and index
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, err := json.Marshal(manifest)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, err := json.Marshal(index)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
require.NoError(t, err)
assert.Contains(t, algoPath, "main.py")
// Verify requirements.txt was also extracted
reqPath := filepath.Join(destDir, "requirements.txt")
_, err = os.Stat(reqPath)
assert.NoError(t, err)
})
}
func TestExtractAlgorithmNoAlgoFile(t *testing.T) {
logger := slog.Default()
t.Run("no algorithm file in layers", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Add a non-algorithm file (e.g., just a readme)
readmeContent := "This is a readme"
readmeHdr := &tar.Header{
Name: "README.md",
Mode: 0o644,
Size: int64(len(readmeContent)),
}
require.NoError(t, tw.WriteHeader(readmeHdr))
_, err = tw.Write([]byte(readmeContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
}
func TestExtractDatasetNoDataFiles(t *testing.T) {
t.Run("no data files in layers", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Add a python file (not a data file)
pyContent := testPythonScript
pyHdr := &tar.Header{
Name: "script.py",
Mode: 0o644,
Size: int64(len(pyContent)),
}
require.NoError(t, tw.WriteHeader(pyHdr))
_, err = tw.Write([]byte(pyContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err = ExtractDataset(ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no dataset files found")
})
}
func TestExtractAlgorithmInvalidManifest(t *testing.T) {
logger := slog.Default()
t.Run("invalid manifest JSON", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
// Write invalid manifest
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), []byte("not json"), 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: 8}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse manifest")
})
}
func TestExtractAlgorithmMissingManifest(t *testing.T) {
logger := slog.Default()
t.Run("manifest file not found", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
// Don't create manifest file
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:missing123", Size: 8}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read manifest")
})
}
func TestExtractAlgorithmWithDirectory(t *testing.T) {
logger := slog.Default()
t.Run("layer with directory entries", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Add a directory entry
dirHdr := &tar.Header{
Name: "src/",
Mode: 0o755,
Typeflag: tar.TypeDir,
}
require.NoError(t, tw.WriteHeader(dirHdr))
// Add algorithm file in subdirectory
algoContent := testPythonScript
algoHdr := &tar.Header{
Name: "src/main.py",
Mode: 0o644,
Size: int64(len(algoContent)),
}
require.NoError(t, tw.WriteHeader(algoHdr))
_, err = tw.Write([]byte(algoContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
require.NoError(t, err)
assert.Contains(t, algoPath, "main.py")
})
}
func TestExtractAlgorithmPathTraversal(t *testing.T) {
logger := slog.Default()
t.Run("path traversal attempt", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
layerPath := filepath.Join(blobsDir, "layer123")
layerFile, err := os.Create(layerPath)
require.NoError(t, err)
gw := gzip.NewWriter(layerFile)
tw := tar.NewWriter(gw)
// Add a file with path traversal attempt
maliciousContent := "malicious"
maliciousHdr := &tar.Header{
Name: "../../../etc/malicious.py",
Mode: 0o644,
Size: int64(len(maliciousContent)),
}
require.NoError(t, tw.WriteHeader(maliciousHdr))
_, err = tw.Write([]byte(maliciousContent))
require.NoError(t, err)
// Add a legit file
algoContent := testPythonScript
algoHdr := &tar.Header{
Name: "algorithm.py",
Mode: 0o644,
Size: int64(len(algoContent)),
}
require.NoError(t, tw.WriteHeader(algoHdr))
_, err = tw.Write([]byte(algoContent))
require.NoError(t, err)
require.NoError(t, tw.Close())
require.NoError(t, gw.Close())
require.NoError(t, layerFile.Close())
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
require.NoError(t, err)
assert.Contains(t, algoPath, "algorithm.py")
// Verify malicious file was NOT extracted outside destDir
_, err = os.Stat("/etc/malicious.py")
assert.True(t, os.IsNotExist(err))
})
}
func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
logger := slog.Default()
t.Run("invalid layer gzip", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "main.py", "print('hello')")
// Corrupt the layer file
layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
err := os.WriteFile(layerPath, []byte("not gzip"), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
t.Run("invalid tar formatting", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "main.py", "print('hello')")
layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
// Create a valid gzip but invalid tar
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
_, err := gw.Write([]byte("not a tar archive but it is gzipped"))
require.NoError(t, err)
gw.Close()
err = os.WriteFile(layerPath, buf.Bytes(), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
t.Run("non-existent layer file", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:nonexistent"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
}
func TestExtractDatasetErrorPathsAdditional(t *testing.T) {
t.Run("invalid layer gzip", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "data.csv", "a,b,c")
layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
err := os.WriteFile(layerPath, []byte("not gzip"), 0o644)
require.NoError(t, err)
_, err = ExtractDataset(ociDir, destDir)
assert.Error(t, err)
})
t.Run("non-existent layer file", func(t *testing.T) {
ociDir := t.TempDir()
destDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:nonexistent"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractDataset(ociDir, destDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no dataset files found")
})
}
+115
View File
@@ -0,0 +1,115 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package oci
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
)
const (
// OCICryptKeyproviderConfig is the environment variable for ocicrypt config.
OCICryptKeyproviderConfig = "OCICRYPT_KEYPROVIDER_CONFIG"
// DefaultOCICryptConfig is the default path to ocicrypt config.
DefaultOCICryptConfig = "/etc/ocicrypt_keyprovider.conf"
// DecryptionKeyProvider is the decryption key provider for CoCo.
DecryptionKeyProvider = "provider:attestation-agent:cc_kbc::null"
)
// SkopeoClient wraps skopeo command-line operations.
type SkopeoClient struct {
skopeoPath string
workDir string
}
// NewSkopeoClient creates a new Skopeo client.
func NewSkopeoClient(workDir string) (*SkopeoClient, error) {
// Find skopeo binary
skopeoPath, err := exec.LookPath("skopeo")
if err != nil {
return nil, fmt.Errorf("skopeo not found in PATH: %w", err)
}
// Ensure work directory exists
if err := os.MkdirAll(workDir, 0o755); err != nil {
return nil, fmt.Errorf("failed to create work directory: %w", err)
}
return &SkopeoClient{
skopeoPath: skopeoPath,
workDir: workDir,
}, nil
}
// PullAndDecrypt pulls an OCI image and decrypts it if encrypted.
func (s *SkopeoClient) PullAndDecrypt(ctx context.Context, source ResourceSource, destDir string) error {
// Ensure destination directory exists
if err := os.MkdirAll(destDir, 0o755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
args := []string{"copy"}
// Add decryption key if image is encrypted
if source.Encrypted {
args = append(args, "--decryption-key", DecryptionKeyProvider)
}
// Add insecure policy for testing (TODO: use proper policy in production)
args = append(args, "--insecure-policy", "--src-tls-verify=false", "--dest-tls-verify=false")
// Source and destination
args = append(args, source.URI, "oci:"+destDir)
cmd := exec.CommandContext(ctx, s.skopeoPath, args...)
// Set OCICRYPT environment
cmd.Env = append(os.Environ(),
OCICryptKeyproviderConfig+"="+DefaultOCICryptConfig)
// Set working directory
cmd.Dir = s.workDir
// Capture output
// Debug: Print full command
fmt.Printf("executing skopeo command: %s %v\n", s.skopeoPath, args)
fmt.Printf("skopeo environment: %s\n", OCICryptKeyproviderConfig+"="+DefaultOCICryptConfig)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("skopeo copy failed: %w\nOutput: %s", err, string(output))
}
return nil
}
// Inspect inspects an OCI image and returns basic manifest information.
func (s *SkopeoClient) Inspect(ctx context.Context, imageRef string) (*ImageManifest, error) {
args := []string{"inspect", "--insecure-policy", "--tls-verify=false", imageRef}
cmd := exec.CommandContext(ctx, s.skopeoPath, args...)
cmd.Env = append(os.Environ(),
OCICryptKeyproviderConfig+"="+DefaultOCICryptConfig)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("skopeo inspect failed: %w\nOutput: %s", err, string(output))
}
// For now, return basic info
// nolint:godox // TODO: Parse JSON output for detailed manifest info
return &ImageManifest{
Reference: imageRef,
}, nil
}
// GetLocalImagePath returns the path to a local OCI image directory.
func (s *SkopeoClient) GetLocalImagePath(name string) string {
return filepath.Join(s.workDir, name)
}
+185
View File
@@ -0,0 +1,185 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package oci
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewSkopeoClient(t *testing.T) {
t.Run("valid work directory", func(t *testing.T) {
workDir := t.TempDir()
client, err := NewSkopeoClient(workDir)
if err != nil && err.Error() == "skopeo not found in PATH: exec: \"skopeo\": executable file not found in $PATH" {
t.Skip("skopeo not installed, skipping test")
}
require.NoError(t, err)
assert.NotNil(t, client)
})
t.Run("new work directory", func(t *testing.T) {
workDir := filepath.Join(t.TempDir(), "new", "nested", "dir")
client, err := NewSkopeoClient(workDir)
if err != nil && err.Error() == "skopeo not found in PATH: exec: \"skopeo\": executable file not found in $PATH" {
t.Skip("skopeo not installed, skipping test")
}
require.NoError(t, err)
assert.NotNil(t, client)
})
}
func TestSkopeoClient_GetLocalImagePath(t *testing.T) {
workDir := t.TempDir()
client, err := NewSkopeoClient(workDir)
if err != nil {
t.Skip("skopeo not installed, skipping test")
}
tests := []struct {
name string
imgName string
expected string
}{
{"simple image name", "myimage", filepath.Join(workDir, "myimage")},
{"image with tag", "myimage:latest", filepath.Join(workDir, "myimage:latest")},
{"nested path", "registry/repo/image", filepath.Join(workDir, "registry/repo/image")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := client.GetLocalImagePath(tt.imgName)
assert.Equal(t, tt.expected, got)
})
}
}
func TestSkopeoClient_PullAndDecrypt(t *testing.T) {
workDir := t.TempDir()
client, err := NewSkopeoClient(workDir)
if err != nil {
t.Skip("skopeo not installed, skipping test")
}
t.Run("invalid source URI", func(t *testing.T) {
ctx := context.Background()
destDir := t.TempDir()
source := ResourceSource{
Type: ResourceTypeOCIImage,
URI: "invalid://not-a-valid-uri",
Encrypted: false,
}
err := client.PullAndDecrypt(ctx, source, destDir)
assert.Error(t, err)
})
t.Run("destination directory created", func(t *testing.T) {
ctx := context.Background()
destDir := filepath.Join(t.TempDir(), "new", "nested", "dest")
source := ResourceSource{
Type: ResourceTypeOCIImage,
URI: "invalid://test",
Encrypted: false,
}
_ = client.PullAndDecrypt(ctx, source, destDir)
_, err := os.Stat(destDir)
assert.NoError(t, err)
})
}
func TestSkopeoClient_Inspect(t *testing.T) {
workDir := t.TempDir()
client, err := NewSkopeoClient(workDir)
if err != nil {
t.Skip("skopeo not installed, skipping test")
}
t.Run("invalid image reference", func(t *testing.T) {
ctx := context.Background()
manifest, err := client.Inspect(ctx, "invalid://not-a-valid-ref")
assert.Error(t, err)
assert.Nil(t, manifest)
})
}
func TestResourceSource(t *testing.T) {
t.Run("ResourceType constants", func(t *testing.T) {
assert.Equal(t, ResourceType("oci-image"), ResourceTypeOCIImage)
})
t.Run("ResourceSource structure", func(t *testing.T) {
source := ResourceSource{
Type: ResourceTypeOCIImage,
URI: "docker://registry/repo:tag",
Encrypted: true,
KBSResourcePath: "default/key/algo-key",
}
assert.Equal(t, ResourceTypeOCIImage, source.Type)
assert.Equal(t, "docker://registry/repo:tag", source.URI)
assert.True(t, source.Encrypted)
assert.Equal(t, "default/key/algo-key", source.KBSResourcePath)
})
}
func TestImageManifest(t *testing.T) {
t.Run("ImageManifest structure", func(t *testing.T) {
manifest := ImageManifest{
Reference: "docker://registry/repo:tag",
Digest: "sha256:abc123",
Layers: []string{"sha256:layer1", "sha256:layer2"},
}
assert.Equal(t, "docker://registry/repo:tag", manifest.Reference)
assert.Equal(t, "sha256:abc123", manifest.Digest)
assert.Len(t, manifest.Layers, 2)
})
}
func TestSkopeoConstants(t *testing.T) {
assert.Equal(t, "OCICRYPT_KEYPROVIDER_CONFIG", OCICryptKeyproviderConfig)
assert.Equal(t, "/etc/ocicrypt_keyprovider.conf", DefaultOCICryptConfig)
assert.Equal(t, "provider:attestation-agent:cc_kbc::null", DecryptionKeyProvider)
}
func TestNewSkopeoClientUnwritableDir(t *testing.T) {
if os.Getuid() == 0 {
t.Skip("cannot test unwritable dir as root")
}
// Create a file where a directory is expected
tmpDir := t.TempDir()
blockingFile := filepath.Join(tmpDir, "blocking")
require.NoError(t, os.WriteFile(blockingFile, []byte("data"), 0o444))
// Try to create a client with workDir inside a file (not a dir)
_, err := NewSkopeoClient(filepath.Join(blockingFile, "subdir"))
assert.Error(t, err)
}
func TestSkopeoClientPullAndDecryptEncrypted(t *testing.T) {
workDir := t.TempDir()
client, err := NewSkopeoClient(workDir)
if err != nil {
t.Skip("skopeo not installed, skipping test")
}
t.Run("encrypted image uses decryption key flag", func(t *testing.T) {
ctx := context.Background()
destDir := t.TempDir()
// Encrypted source - skopeo call will fail but the --decryption-key arg is built
source := ResourceSource{
Type: ResourceTypeOCIImage,
URI: "docker://invalid.registry/nonexistent:latest",
Encrypted: true,
}
err := client.PullAndDecrypt(ctx, source, destDir)
// We expect an error (no such image) but the encrypted code path was exercised
assert.Error(t, err)
assert.Contains(t, err.Error(), "skopeo copy failed")
})
}
+40
View File
@@ -0,0 +1,40 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package oci
// ResourceType defines the type of OCI resource.
type ResourceType string
const (
// ResourceTypeOCIImage represents a standard OCI image.
ResourceTypeOCIImage ResourceType = "oci-image"
)
// ResourceSource defines the source of an OCI resource.
type ResourceSource struct {
// Type of resource (oci-image)
Type ResourceType `json:"type"`
// URI is the OCI image reference (e.g., "docker://registry/repo:tag")
URI string `json:"uri"`
// Encrypted indicates if the image is encrypted
Encrypted bool `json:"encrypted"`
// KBSResourcePath is the KBS resource path for the decryption key
// (e.g., "default/key/algo-key")
KBSResourcePath string `json:"kbs_resource_path,omitempty"`
}
// ImageManifest represents basic OCI image manifest information.
type ImageManifest struct {
// Reference is the original image reference
Reference string
// Digest is the image digest
Digest string
// Layers are the layer digests
Layers []string
}