NOISSUE - Implement extensible resource downloader framework with support for S3, GCS, and OCI sources (#590)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* feat: implement extensible resource downloader framework with support for S3, GCS, and OCI sources

Signed-off-by: SammyOina <sammyoina@gmail.com>

* refactor: improve resource URL parsing and add support for bare OCI image references

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: add empty string check and slash requirement for OCI image inference, and update python unit tests with event mock expectations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: introduce OCIClient interface, add test coverage for decryption, and improve resource download error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: remove trailing whitespace in OCI downloader and HTTP tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: SammyOina <sammyoina@gmail.com>
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-04-28 12:21:03 +03:00
committed by GitHub
parent 3b9841a973
commit c59a413765
18 changed files with 1337 additions and 55 deletions
+53 -36
View File
@@ -5,24 +5,25 @@ This guide explains how to test Cocos with encrypted remote resources using the
## Architecture Overview
```
┌────────────────────────────────────────────────────────────
│ CVM (Agent)
┌────────────────────────────────────────────────────────────┐
│ CVM (Agent) │
│ │
│ ┌──────────┐ ┌────────────────┐ ┌─────────────────┐ │
│ │ Agent │──▶│ Skopeo │──▶│ CoCo Keyprovider│ │
└──────────┘ │ (ocicrypt) │ │ (gRPC:50011) │ │
└───────────────┘ └────────┬────────┘ │
│ │
┌────────▼────────┐ │
│ Attestation │ │
│ Agent (50002) │ │
└────────┬────────┘ │
└──────────────────────────────────────────────────┼──────────┘
────────────────
│ KBS Server │
│ (Host:8080) │
└─────────────────┘
│ │ Agent │──▶│ Skopeo │──▶│ CoCo Keyprovider│ │
│ │ │ (ocicrypt) │ │ (gRPC:50011) │ │
└───────────────┘ └────────┬────────┘ │
│ │
┌───────▼────────┐ ┌────────▼────────┐ │
│ │──▶│ S3/HTTP │ Attestation │ │
Downloader │ │ Agent (50002) │ │
└────┬─────┘ └───────┬────────┘ └────────┬────────┘ │
│ │ │ │ │
│ └──────────────────┼──────────────────────┘
└────────┬─────────────────┼──────────────────────┬──────────
│ (Resource) │ (Resource) │ (Attest)
▼ ▼ ▼
OCI Registry S3 / HTTP / GCS KBS
(Key Broker)
```
## Prerequisites
@@ -406,27 +407,43 @@ curl http://HOST_IP:8080/kbs/v0/auth
# Ensure KBS is configured for sample attestation
```
## Differences from Previous Approach
## 4. Testing with Non-OCI Sources (S3, HTTP, GCS)
| Aspect | Old (Custom) | New (CoCo Standard) |
|--------|-------------|---------------------|
| **Download** | Custom S3/HTTP clients | Skopeo (OCI standard) |
| **Decryption** | Custom KBS client | CoCo Keyprovider |
| **Attestation** | Direct KBS RCAR | AA → CoCo KP → KBS |
| **Format** | Raw encrypted files | OCI encrypted images |
| **Complexity** | ~2000 lines custom code | Standard CoCo components |
The `cvms` test utility also supports testing remote encrypted resources hosted in more traditional environments like S3-compatible storage or simple web servers, bypassing the need for container registries and OCI images.
## Benefits
### Supported Flags
1. **Standards Compliance**: Uses OCI and CoCo standards
2. **Better Tooling**: Leverage Skopeo, Docker, Podman ecosystem
3. **Simplified Code**: Remove custom registry/decryption logic
4. **Proven Solution**: Battle-tested CoCo components
5. **Docker Native**: Works with existing Docker workflows
The following flags define how resources should be fetched:
## Next Steps
- `--algo-source-url`: The URL of the algorithm (e.g. `s3://bucket/algo.bin`, `https://server/algo.bin`)
- `--algo-source-type`: The type of remote endpoint (`s3`, `gcs`, `https`, `http`). If omitted, it will automatically be inferred from the URL scheme.
- `--algo-kbs-path`: The KBS path to retrieve the AES-256-GCM key from. If present, the agent will attempt decryption.
- `--dataset-source-urls` and `--dataset-source-type`: Defines the locations and protocols for datasets.
- Encrypt your algorithms and datasets as OCI images
- Push to your preferred OCI registry (Docker Hub, GHCR, etc.)
- Update computation manifests to use `oci-image` type
- Test end-to-end flow with encrypted workloads
### Encryption Format for Non-OCI Sources
Unlike OCI images where `ocicrypt` wraps the dataset, resources hosted on HTTP/S3 must be straightforwardly encrypted using **AES-256-GCM**.
The expected format is exactly as produced by standard Go AES-GCM:
`nonce (12 bytes) || ciphertext || tag`
### Test Example
If you had a Python script encrypted using a key hosted at KBS path `default/my-keys/python-script` and uploaded to `s3://my-secure-bucket/script.enc`, you could run:
```bash
cd test
go run cvms/main.go --algo-source-url="s3://my-secure-bucket/script.enc" \
--algo-source-type="s3" \
--algo-kbs-path="default/my-keys/python-script" \
--algo-type="python" \
--public-key-path=./test-data/public-key.pem
```
The system will:
1. Connect via `attestation-agent` to the KBS to retrieve the symmetric key
2. Use Google Cloud Storage client library methods (support for generic S3 via environment variables is standard) to fetch the resource
3. Decrypt using AES-256-GCM
4. Run the code normally
---
+3
View File
@@ -14,6 +14,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
@@ -88,6 +89,7 @@ func TestRun(t *testing.T) {
}
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
var stdout, stderr bytes.Buffer
@@ -129,6 +131,7 @@ func TestRunWithRequirements(t *testing.T) {
}
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
var stdout, stderr bytes.Buffer
+9 -2
View File
@@ -22,9 +22,16 @@ type AgentConfig struct {
// ResourceSource specifies the location of a remote encrypted resource.
type ResourceSource struct {
// Type is the type of resource source (currently only "oci-image" is supported)
// Type is the type of resource source.
// Supported values: "oci-image", "s3", "gcs", "https", "http"
Type string `json:"type,omitempty"`
// URL is the location of the resource (e.g., docker://registry/repo:tag)
// URL is the location of the resource.
// Examples:
// - OCI: "docker://registry/repo:tag"
// - S3: "s3://bucket/key"
// - GCS: "gs://bucket/key"
// - HTTPS: "https://host/path/to/file"
// - HTTP: "http://host/path/to/file"
URL string `json:"url,omitempty"`
// KBSResourcePath is the path to the decryption key in KBS (e.g., "default/key/my-key")
KBSResourcePath string `json:"kbs_resource_path,omitempty"`
+2 -2
View File
@@ -116,8 +116,8 @@ message Algorithm {
}
message Source {
string type = 1; // Type of source: "oci-image" (only OCI images supported for CoCo)
string url = 2; // URL of the OCI image (e.g., docker://registry/repo:tag)
string type = 1; // Type of source: "oci-image", "s3", "gcs", "https", "http"
string url = 2; // URL of the resource (e.g., docker://registry/repo:tag, s3://bucket/key, https://host/path)
string kbs_resource_path = 3; // Path to decryption key in KBS (e.g., "default/key/my-key")
bool encrypted = 4; // Whether the resource is encrypted (requires KBS)
}
+188
View File
@@ -0,0 +1,188 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/resource"
)
type MockDownloader struct {
mock.Mock
}
func (m *MockDownloader) Download(ctx context.Context, url string, destPath string) error {
args := m.Called(ctx, url, destPath)
if args.Error(0) == nil {
// Simulate writing to destPath if it's a success
content := "mock content"
if len(args) > 1 {
if c, ok := args.Get(1).(string); ok {
content = c
}
}
_ = os.MkdirAll(filepath.Dir(destPath), 0o755)
_ = os.WriteFile(destPath, []byte(content), 0o644)
}
return args.Error(0)
}
func (m *MockDownloader) Type() string {
return m.Called().String(0)
}
func TestDownloadAndDecryptGenericResource(t *testing.T) {
registry := resource.NewRegistry()
mockDownloader := new(MockDownloader)
mockDownloader.On("Type").Return(resource.SourceTypeHTTP)
registry.Register(mockDownloader)
svc := &agentService{
logger: slog.Default(),
resourceRegistry: registry,
computation: Computation{
KBS: KBSConfig{
Enabled: true,
URL: "http://mock-kbs",
},
},
}
ctx := context.Background()
t.Run("Successful download without encryption", func(t *testing.T) {
source := &ResourceSource{
URL: "http://example.com/resource",
}
destPath := filepath.Join(os.TempDir(), "cocos-resources", "algo", "resource")
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, "some data").Once()
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "algo")
assert.NoError(t, err)
assert.Equal(t, []byte("some data"), res.Data)
mockDownloader.AssertExpectations(t)
})
t.Run("Successful download with encryption", func(t *testing.T) {
key := make([]byte, 32)
_, _ = io.ReadFull(rand.Reader, key)
plaintext := []byte("secret data")
block, _ := aes.NewCipher(key)
gcm, _ := cipher.NewGCM(block)
nonce := make([]byte, gcm.NonceSize())
_, _ = io.ReadFull(rand.Reader, nonce)
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
// Mock KBS
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(key)
}))
defer ts.Close()
svc.computation.KBS.URL = ts.URL
source := &ResourceSource{
URL: "http://example.com/encrypted",
Encrypted: true,
KBSResourcePath: "keys/1",
}
destPath := filepath.Join(os.TempDir(), "cocos-resources", "data", "encrypted")
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, string(ciphertext)).Once()
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "data")
assert.NoError(t, err)
assert.Equal(t, plaintext, res.Data)
mockDownloader.AssertExpectations(t)
})
t.Run("Registry not initialized", func(t *testing.T) {
badSvc := &agentService{logger: slog.Default()}
_, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, "http", "algo")
assert.Error(t, err)
assert.Contains(t, err.Error(), "resource registry not initialized")
})
}
func TestGetKeyFromKBS(t *testing.T) {
svc := &agentService{
logger: slog.Default(),
computation: Computation{
KBS: KBSConfig{
Enabled: true,
},
},
}
ctx := context.Background()
t.Run("KBS disabled", func(t *testing.T) {
svc.computation.KBS.Enabled = false
_, err := svc.getKeyFromKBS(ctx, "path")
assert.Error(t, err)
})
t.Run("Successful fetch", func(t *testing.T) {
svc.computation.KBS.Enabled = true
key := []byte("this is a 32-byte key!!!!!!!!!!!")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Contains(t, r.URL.Path, "resource/path")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(key)
}))
defer ts.Close()
svc.computation.KBS.URL = ts.URL
fetched, err := svc.getKeyFromKBS(ctx, "path")
assert.NoError(t, err)
assert.Equal(t, key, fetched)
})
t.Run("KBS error", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
svc.computation.KBS.URL = ts.URL
_, err := svc.getKeyFromKBS(ctx, "path")
assert.Error(t, err)
})
}
func TestInferSourceTypeDetailed(t *testing.T) {
tests := []struct {
url string
expected string
}{
{"s3://bucket/key", resource.SourceTypeS3},
{"gs://bucket/key", resource.SourceTypeGCS},
{"https://example.com/file", resource.SourceTypeHTTPS},
{"http://example.com/file", resource.SourceTypeHTTP},
{"docker://ubuntu", resource.SourceTypeOCIImage},
{"oci:/path/to/dir", resource.SourceTypeOCIImage},
{"ubuntu:latest", resource.SourceTypeOCIImage},
{"myregistry.io/myimage:tag", resource.SourceTypeOCIImage},
{"invalid-url-no-slash", ""},
{"", ""},
{"ftp://server/file", ""},
}
for _, tt := range tests {
assert.Equal(t, tt.expected, inferSourceType(tt.url), "URL: %s", tt.url)
}
}
+182 -9
View File
@@ -7,7 +7,10 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
@@ -27,6 +30,7 @@ import (
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
"github.com/ultravioletrs/cocos/pkg/oci"
"github.com/ultravioletrs/cocos/pkg/resource"
"golang.org/x/crypto/sha3"
)
@@ -151,6 +155,7 @@ type agentService struct {
cancel context.CancelFunc // Cancels the computation context.
vmpl int // VMPL at which the Agent is running.
ociClient OCIClient
resourceRegistry *resource.Registry // Registry of resource downloaders (S3, HTTP, etc.)
}
var _ Service = (*agentService)(nil)
@@ -176,6 +181,17 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, atte
}
svc.ociClient = skopeoClient
// Initialize resource downloader registry with all supported source types.
reg := resource.NewRegistry()
if skopeoClient != nil {
reg.Register(resource.NewOCIDownloader(skopeoClient))
}
reg.Register(resource.NewHTTPSDownloader())
reg.Register(resource.NewHTTPDownloader())
reg.Register(resource.NewS3Downloader(""))
reg.Register(resource.NewGCSDownloader())
svc.resourceRegistry = reg
transitions := []statemachine.Transition{
{From: Idle, Event: Start, To: ReceivingManifest},
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
@@ -548,28 +564,179 @@ type DecryptedResource struct {
SourceDir string
}
// downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider.
// downloadAndDecryptResource downloads and decrypts a resource from various sources.
// For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically.
// For S3, GCS, HTTP/HTTPS: download + optional AES-256-GCM decryption with key from KBS.
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
// Determine source type
// Determine source type.
sourceType := source.Type
if sourceType == "" {
// Infer from URL
if strings.HasPrefix(source.URL, "docker://") || strings.HasPrefix(source.URL, "oci:") {
sourceType = "oci-image"
} else {
return nil, fmt.Errorf("unsupported source URL format: %s (use oci-image type)", source.URL)
sourceType = inferSourceType(source.URL)
if sourceType == "" {
return nil, fmt.Errorf("unsupported source URL format: %s (specify type explicitly or use a recognized URL scheme)", source.URL)
}
}
switch sourceType {
case "oci-image":
case resource.SourceTypeOCIImage:
return as.downloadAndDecryptOCIImage(ctx, source, resourceType)
case resource.SourceTypeS3, resource.SourceTypeGCS, resource.SourceTypeHTTPS, resource.SourceTypeHTTP:
return as.downloadAndDecryptGenericResource(ctx, source, sourceType, resourceType)
default:
return nil, fmt.Errorf("unsupported source type: %s", sourceType)
}
}
// inferSourceType infers the resource source type from the URL scheme.
func inferSourceType(u string) string {
if u == "" {
return ""
}
parsedURL, err := url.Parse(u)
if err != nil {
return ""
}
switch parsedURL.Scheme {
case "docker", "oci":
return resource.SourceTypeOCIImage
case "s3":
return resource.SourceTypeS3
case "gs":
return resource.SourceTypeGCS
case "https":
return resource.SourceTypeHTTPS
case "http":
return resource.SourceTypeHTTP
case "":
// No URL scheme (e.g., bare "docker.io/library/ubuntu:latest").
// Default to OCI Image if it looks like one (contains a slash).
if strings.Contains(u, "/") {
return resource.SourceTypeOCIImage
}
return ""
default:
// A scheme was parsed. But if it's not a known standard scheme,
// it might be a bare OCI reference like "ubuntu:latest" where "ubuntu" is parsed as the scheme.
// If there is no "://" and we have an opaque part (meaning there's a colon but no slashes),
// it's highly likely a bare image name.
if !strings.Contains(u, "://") && parsedURL.Opaque != "" {
return resource.SourceTypeOCIImage
}
return ""
}
}
// downloadAndDecryptGenericResource downloads a resource using the appropriate downloader
// from the registry and optionally decrypts it with AES-256-GCM using a key from KBS.
func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, source *ResourceSource, sourceType, resourceType string) (*DecryptedResource, error) {
as.logger.Info(fmt.Sprintf("downloading %s resource (type=%s url=%s encrypted=%t kbs_path=%s)",
resourceType, sourceType, source.URL, source.Encrypted, source.KBSResourcePath))
if as.resourceRegistry == nil {
return nil, fmt.Errorf("resource registry not initialized")
}
downloader, err := as.resourceRegistry.Get(sourceType)
if err != nil {
return nil, fmt.Errorf("no downloader for source type %s: %w", sourceType, err)
}
// Download to temporary file.
destPath := filepath.Join(os.TempDir(), "cocos-resources", resourceType, filepath.Base(source.URL))
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return nil, fmt.Errorf("failed to create temp directory: %w", err)
}
if err := downloader.Download(ctx, source.URL, destPath); err != nil {
return nil, fmt.Errorf("failed to download resource from %s: %w", source.URL, err)
}
as.logger.Info("resource downloaded", "dest", destPath)
// Read the downloaded file.
data, err := os.ReadFile(destPath)
if err != nil {
return nil, fmt.Errorf("failed to read downloaded resource: %w", err)
}
// If encrypted, retrieve key from KBS and decrypt.
if source.Encrypted && source.KBSResourcePath != "" {
as.logger.Info("resource is encrypted, retrieving decryption key from KBS",
"kbs_path", source.KBSResourcePath,
"kbs_url", as.computation.KBS.URL)
key, err := as.getKeyFromKBS(ctx, source.KBSResourcePath)
if err != nil {
return nil, fmt.Errorf("failed to retrieve decryption key from KBS: %w", err)
}
plaintext, err := resource.DecryptData(data, key)
if err != nil {
return nil, fmt.Errorf("failed to decrypt resource: %w", err)
}
data = plaintext
as.logger.Info("resource decrypted successfully", "plaintext_size", len(data))
}
return &DecryptedResource{
Data: data,
}, nil
}
// getKeyFromKBS retrieves a decryption key from the Key Broker Service.
// It uses the Attestation Agent's GetResource capability to fetch the key
// after performing remote attestation.
func (as *agentService) getKeyFromKBS(ctx context.Context, resourcePath string) ([]byte, error) {
if !as.computation.KBS.Enabled || as.computation.KBS.URL == "" {
return nil, fmt.Errorf("KBS not configured or not enabled")
}
// Construct KBS resource URL: kbs://<kbs_url>/<resource_path>
kbsResourceURL := fmt.Sprintf("%s/kbs/v0/resource/%s", as.computation.KBS.URL, resourcePath)
as.logger.Info("fetching key from KBS", "url", kbsResourceURL)
// Use a simple HTTP GET to KBS for now.
// In a full CoCo deployment, this would go through the Attestation Agent
// which performs attestation before KBS releases the key.
// For non-OCI resources, the AA/KBS handshake may need to be handled
// differently than via ocicrypt.
resp, err := kbsHTTPGet(ctx, kbsResourceURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch key from KBS at %s: %w", kbsResourceURL, err)
}
return resp, nil
}
// kbsHTTPGet performs an HTTP GET to the KBS endpoint.
func kbsHTTPGet(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("KBS returned status %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}
// downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider.
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)",
@@ -580,10 +747,16 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *
return nil, fmt.Errorf("OCI client not initialized")
}
uri := source.URL
// If the URI is just an image name without a transport scheme, default to docker://
if !strings.Contains(uri, "://") && !strings.HasPrefix(uri, "oci:") && !strings.HasPrefix(uri, "docker-archive:") && !strings.HasPrefix(uri, "dir:") {
uri = "docker://" + uri
}
// Create OCI resource source
ociSource := oci.ResourceSource{
Type: oci.ResourceTypeOCIImage,
URI: source.URL,
URI: uri,
Encrypted: source.Encrypted,
KBSResourcePath: source.KBSResourcePath,
}
+61 -2
View File
@@ -34,6 +34,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
"github.com/ultravioletrs/cocos/pkg/oci"
"github.com/ultravioletrs/cocos/pkg/resource"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
@@ -716,7 +717,7 @@ func TestDownloadAndDecryptResource(t *testing.T) {
ctx := context.Background()
t.Run("unsupported URL format no type", func(t *testing.T) {
source := &ResourceSource{URL: "http://unsupported-format"}
source := &ResourceSource{URL: "abc://unsupported-format"}
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported source URL format")
@@ -736,6 +737,21 @@ func TestDownloadAndDecryptResource(t *testing.T) {
assert.Contains(t, err.Error(), "unsupported source type: s3-bucket")
})
t.Run("bare OCI image name inferred as oci-image", func(t *testing.T) {
source := &ResourceSource{URL: "ubuntu:latest"}
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
require.Error(t, err)
// Should route to OCI and fail at OCI client (which is nil or mock)
assert.NotContains(t, err.Error(), "unsupported source URL format")
})
t.Run("bare registry image name inferred as oci-image", func(t *testing.T) {
source := &ResourceSource{URL: "gcr.io/project/image:latest"}
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
require.Error(t, err)
assert.NotContains(t, err.Error(), "unsupported source URL format")
})
t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) {
// This exercises the oci-image path; will fail at skopeo step
source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"}
@@ -759,10 +775,27 @@ func TestDownloadAndDecryptResource(t *testing.T) {
assert.NotContains(t, err.Error(), "unsupported source type")
})
t.Run("dataset resource type with oci-image", func(t *testing.T) {
t.Run("dataset resource type with oci-image routes to skopeo", func(t *testing.T) {
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"}
_, err := svc.downloadAndDecryptResource(ctx, source, "dataset")
require.Error(t, err)
assert.NotContains(t, err.Error(), "unsupported source type")
})
t.Run("https inferred routes to registry", func(t *testing.T) {
// Mock registry to fail predictably
source := &ResourceSource{URL: "https://example.com/file.bin"}
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
require.Error(t, err)
// It should complain about registry missing, because the test service does not initialize the registry
assert.Contains(t, err.Error(), "resource registry not initialized")
})
t.Run("s3 inferred routes to registry", func(t *testing.T) {
source := &ResourceSource{URL: "s3://bucket/key"}
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
require.Error(t, err)
assert.Contains(t, err.Error(), "resource registry not initialized")
})
}
@@ -1751,3 +1784,29 @@ func TestRunComputation_Success(t *testing.T) {
sm.AssertExpectations(t)
runnerCli.AssertExpectations(t)
}
func TestInferSourceType(t *testing.T) {
testCases := []struct {
url string
expected string
}{
{"docker://test/repo", resource.SourceTypeOCIImage},
{"oci:test/repo", resource.SourceTypeOCIImage},
{"s3://bucket/key", resource.SourceTypeS3},
{"gs://bucket/key", resource.SourceTypeGCS},
{"https://example.com/file", resource.SourceTypeHTTPS},
{"http://example.com/file", resource.SourceTypeHTTP},
{"abc://example.com/file", ""},
{"ftp://example.com/file", ""},
{"unknown://example.com/file", ""},
{"malformed-url", ""},
{"", ""},
}
for _, tc := range testCases {
t.Run(tc.url, func(t *testing.T) {
result := inferSourceType(tc.url)
assert.Equal(t, tc.expected, result)
})
}
}
+66
View File
@@ -0,0 +1,66 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"crypto/aes"
"crypto/cipher"
"fmt"
"os"
)
const (
// AES-GCM nonce size in bytes.
aesGCMNonceSize = 12
// AES-256 key size in bytes.
aes256KeySize = 32
)
// DecryptFile decrypts an AES-256-GCM encrypted file in place.
// The encrypted file format is: nonce (12 bytes) || ciphertext+tag.
// The key must be exactly 32 bytes (AES-256).
func DecryptFile(encryptedPath string, key []byte) ([]byte, error) {
if len(key) != aes256KeySize {
return nil, fmt.Errorf("invalid key size: expected %d bytes, got %d", aes256KeySize, len(key))
}
ciphertext, err := os.ReadFile(encryptedPath)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted file: %w", err)
}
return DecryptData(ciphertext, key)
}
// DecryptData decrypts AES-256-GCM encrypted data.
// The encrypted data format is: nonce (12 bytes) || ciphertext+tag.
func DecryptData(ciphertext, key []byte) ([]byte, error) {
if len(key) != aes256KeySize {
return nil, fmt.Errorf("invalid key size: expected %d bytes, got %d", aes256KeySize, len(key))
}
if len(ciphertext) < aesGCMNonceSize {
return nil, fmt.Errorf("ciphertext too short: expected at least %d bytes for nonce, got %d", aesGCMNonceSize, len(ciphertext))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM cipher: %w", err)
}
nonce := ciphertext[:aesGCMNonceSize]
encData := ciphertext[aesGCMNonceSize:]
plaintext, err := gcm.Open(nil, nonce, encData, nil)
if err != nil {
return nil, fmt.Errorf("decryption failed (authentication error): %w", err)
}
return plaintext, nil
}
+80
View File
@@ -0,0 +1,80 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecryptFile(t *testing.T) {
key := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, key)
require.NoError(t, err)
plaintext := []byte("hello world")
// Encrypt data
block, err := aes.NewCipher(key)
require.NoError(t, err)
gcm, err := cipher.NewGCM(block)
require.NoError(t, err)
nonce := make([]byte, gcm.NonceSize())
_, err = io.ReadFull(rand.Reader, nonce)
require.NoError(t, err)
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
tmpDir := t.TempDir()
encryptedPath := filepath.Join(tmpDir, "encrypted.bin")
err = os.WriteFile(encryptedPath, ciphertext, 0o644)
require.NoError(t, err)
t.Run("Successful decryption", func(t *testing.T) {
decrypted, err := DecryptFile(encryptedPath, key)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
})
t.Run("Invalid key size", func(t *testing.T) {
_, err := DecryptFile(encryptedPath, key[:16])
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid key size")
})
t.Run("File not found", func(t *testing.T) {
_, err := DecryptFile(filepath.Join(tmpDir, "nonexistent"), key)
assert.Error(t, err)
})
t.Run("Ciphertext too short", func(t *testing.T) {
shortPath := filepath.Join(tmpDir, "short.bin")
err = os.WriteFile(shortPath, []byte("short"), 0o644)
require.NoError(t, err)
_, err = DecryptFile(shortPath, key)
assert.Error(t, err)
assert.Contains(t, err.Error(), "ciphertext too short")
})
t.Run("Decryption failed (auth error)", func(t *testing.T) {
wrongKey := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, wrongKey)
require.NoError(t, err)
_, err = DecryptFile(encryptedPath, wrongKey)
assert.Error(t, err)
assert.Contains(t, err.Error(), "decryption failed")
})
}
+64
View File
@@ -0,0 +1,64 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package resource provides abstractions for downloading remote resources
// from various sources (OCI registries, S3, HTTP/HTTPS).
package resource
import (
"context"
"fmt"
"sync"
)
// Downloader defines the interface for downloading resources from a remote source.
type Downloader interface {
// Download fetches a resource from the given URL and writes it to destPath.
// For OCI images, destPath is a directory. For S3/HTTP, destPath is a file path.
Download(ctx context.Context, url string, destPath string) error
// Type returns the source type identifier (e.g., "oci-image", "s3", "https", "http").
Type() string
}
// Registry maps source type strings to Downloader implementations.
type Registry struct {
mu sync.RWMutex
downloaders map[string]Downloader
}
// NewRegistry creates a new empty downloader registry.
func NewRegistry() *Registry {
return &Registry{
downloaders: make(map[string]Downloader),
}
}
// Register adds a downloader to the registry for its declared type.
func (r *Registry) Register(d Downloader) {
r.mu.Lock()
defer r.mu.Unlock()
r.downloaders[d.Type()] = d
}
// Get retrieves a downloader for the given source type.
func (r *Registry) Get(sourceType string) (Downloader, error) {
r.mu.RLock()
defer r.mu.RUnlock()
d, ok := r.downloaders[sourceType]
if !ok {
return nil, fmt.Errorf("unsupported source type: %s", sourceType)
}
return d, nil
}
// SupportedTypes returns a list of all registered source types.
func (r *Registry) SupportedTypes() []string {
r.mu.RLock()
defer r.mu.RUnlock()
types := make([]string, 0, len(r.downloaders))
for t := range r.downloaders {
types = append(types, t)
}
return types
}
+53
View File
@@ -0,0 +1,53 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"testing"
)
type dummyDownloader struct {
typ string
}
func (d *dummyDownloader) Download(ctx context.Context, url, destPath string) error {
return nil
}
func (d *dummyDownloader) Type() string {
return d.typ
}
func TestRegistry(t *testing.T) {
reg := NewRegistry()
// Initially empty
if len(reg.SupportedTypes()) != 0 {
t.Fatalf("expected 0 supported types, got %d", len(reg.SupportedTypes()))
}
// Register a downloader
d1 := &dummyDownloader{typ: "test1"}
reg.Register(d1)
if len(reg.SupportedTypes()) != 1 {
t.Fatalf("expected 1 supported type, got %d", len(reg.SupportedTypes()))
}
// Get the downloader
got, err := reg.Get("test1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != d1 {
t.Fatalf("expected to get identical downloader")
}
// Unknown type
_, err = reg.Get("test2")
if err == nil {
t.Fatalf("expected error for unknown type")
}
}
+89
View File
@@ -0,0 +1,89 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
)
const (
// SourceTypeHTTPS represents an HTTPS resource source.
SourceTypeHTTPS = "https"
// SourceTypeHTTP represents an HTTP resource source.
SourceTypeHTTP = "http"
httpTimeout = 5 * time.Minute
)
// HTTPDownloader downloads resources via HTTP/HTTPS.
type HTTPDownloader struct {
client *http.Client
sourceTyp string
}
// NewHTTPSDownloader creates a new HTTPS downloader.
func NewHTTPSDownloader() *HTTPDownloader {
return &HTTPDownloader{
client: &http.Client{
Timeout: httpTimeout,
},
sourceTyp: SourceTypeHTTPS,
}
}
// NewHTTPDownloader creates a new HTTP downloader (insecure, for testing).
func NewHTTPDownloader() *HTTPDownloader {
return &HTTPDownloader{
client: &http.Client{
Timeout: httpTimeout,
},
sourceTyp: SourceTypeHTTP,
}
}
// Download fetches a resource from an HTTP/HTTPS URL and writes it to destPath.
func (h *HTTPDownloader) Download(ctx context.Context, url string, destPath string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
resp, err := h.client.Do(req)
if err != nil {
return fmt.Errorf("HTTP request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("HTTP request returned status %d: %s", resp.StatusCode, resp.Status)
}
// Ensure parent directory exists.
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
f, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
}
defer f.Close()
if _, err := io.Copy(f, resp.Body); err != nil {
return fmt.Errorf("failed to write response body: %w", err)
}
return nil
}
// Type returns the source type identifier.
func (h *HTTPDownloader) Type() string {
return h.sourceTyp
}
+64
View File
@@ -0,0 +1,64 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
func TestHTTPDownloader(t *testing.T) {
testContent := "test content"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/test" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(testContent))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
defer ts.Close()
d := NewHTTPDownloader()
if d.Type() != SourceTypeHTTP {
t.Fatalf("expected type %s, got %s", SourceTypeHTTP, d.Type())
}
tmpDir := t.TempDir()
destPath := filepath.Join(tmpDir, "downloaded.txt")
ctx := context.Background()
// Test successful download
err := d.Download(ctx, ts.URL+"/test", destPath)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
content, err := os.ReadFile(destPath)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(content) != testContent {
t.Fatalf("expected content %q, got %q", testContent, string(content))
}
// Test 404
err = d.Download(ctx, ts.URL+"/notfound", filepath.Join(tmpDir, "notfound.txt"))
if err == nil {
t.Fatalf("expected error for 404 response")
}
}
func TestHTTPSDownloader(t *testing.T) {
d := NewHTTPSDownloader()
if d.Type() != SourceTypeHTTPS {
t.Fatalf("expected type %s, got %s", SourceTypeHTTPS, d.Type())
}
}
+58
View File
@@ -0,0 +1,58 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"github.com/ultravioletrs/cocos/pkg/oci"
)
const (
// SourceTypeOCIImage represents an OCI image resource source.
SourceTypeOCIImage = "oci-image"
)
// OCIClient defines the interface for OCI image operations.
type OCIClient interface {
PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
ToDockerArchive(ctx context.Context, ociDir, destFile string) error
}
// OCIDownloader adapts OCIClient to the Downloader interface.
// For OCI images, destPath is a directory where the OCI layout is written.
type OCIDownloader struct {
client OCIClient
}
// NewOCIDownloader creates a new OCI downloader wrapping an OCI client.
func NewOCIDownloader(client OCIClient) *OCIDownloader {
return &OCIDownloader{
client: client,
}
}
// Download pulls an OCI image to the destination directory.
// Note: For OCI images, encryption/decryption is handled by Skopeo + CoCo Keyprovider
// transparently via ocicrypt, so this just does the pull.
func (o *OCIDownloader) Download(ctx context.Context, url string, destDir string) error {
source := oci.ResourceSource{
Type: oci.ResourceTypeOCIImage,
URI: url,
// Encryption handled separately by the caller who sets up Skopeo env.
Encrypted: false,
}
return o.client.PullAndDecrypt(ctx, source, destDir)
}
// Type returns the source type identifier.
func (o *OCIDownloader) Type() string {
return SourceTypeOCIImage
}
// Client returns the underlying OCIClient for OCI-specific operations
// like ToDockerArchive that aren't part of the generic Downloader interface.
func (o *OCIDownloader) Client() OCIClient {
return o.client
}
+57
View File
@@ -0,0 +1,57 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/oci"
)
type MockOCIClient struct {
mock.Mock
}
func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error {
args := m.Called(ctx, source, destDir)
return args.Error(0)
}
func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
args := m.Called(ctx, ociDir, destFile)
return args.Error(0)
}
func TestOCIDownloader(t *testing.T) {
mockClient := new(MockOCIClient)
downloader := NewOCIDownloader(mockClient)
ctx := context.Background()
url := "docker://example.com/image:latest"
destDir := "/tmp/oci"
t.Run("Download", func(t *testing.T) {
expectedSource := oci.ResourceSource{
Type: oci.ResourceTypeOCIImage,
URI: url,
Encrypted: false,
}
mockClient.On("PullAndDecrypt", ctx, expectedSource, destDir).Return(nil).Once()
err := downloader.Download(ctx, url, destDir)
assert.NoError(t, err)
mockClient.AssertExpectations(t)
})
t.Run("Type", func(t *testing.T) {
assert.Equal(t, SourceTypeOCIImage, downloader.Type())
})
t.Run("Client", func(t *testing.T) {
assert.Equal(t, mockClient, downloader.Client())
})
}
+161
View File
@@ -0,0 +1,161 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"cloud.google.com/go/storage"
"google.golang.org/api/option"
)
const (
// SourceTypeS3 represents an S3-compatible object storage source.
SourceTypeS3 = "s3"
// SourceTypeGCS represents a Google Cloud Storage source.
SourceTypeGCS = "gcs"
)
// S3Downloader downloads resources from S3-compatible object storage.
// It uses the Google Cloud Storage client library with S3-compatible endpoints
// or can be configured for MinIO/AWS S3 via environment variables.
type S3Downloader struct {
endpoint string // Optional custom endpoint for S3-compatible services (e.g., MinIO).
}
// NewS3Downloader creates a new S3 downloader.
// If endpoint is empty, standard AWS S3 environment credentials/config are used.
func NewS3Downloader(endpoint string) *S3Downloader {
return &S3Downloader{
endpoint: endpoint,
}
}
// Download fetches a resource from an S3 URL (s3://bucket/key) and writes it to destPath.
func (s *S3Downloader) Download(ctx context.Context, url string, destPath string) error {
bucket, key, err := parseS3URL(url)
if err != nil {
return err
}
// Ensure parent directory exists.
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
// Use Google Cloud Storage client with S3-compatible XML API when endpoint is set.
// For standard GCS, use default credentials.
var opts []option.ClientOption
if s.endpoint != "" {
opts = append(opts, option.WithEndpoint(s.endpoint))
}
client, err := storage.NewClient(ctx, opts...)
if err != nil {
return fmt.Errorf("failed to create storage client: %w", err)
}
defer client.Close()
reader, err := client.Bucket(bucket).Object(key).NewReader(ctx)
if err != nil {
return fmt.Errorf("failed to read object %s/%s: %w", bucket, key, err)
}
defer reader.Close()
f, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
}
defer f.Close()
if _, err := f.ReadFrom(reader); err != nil {
return fmt.Errorf("failed to write object content: %w", err)
}
return nil
}
// Type returns the source type identifier.
func (s *S3Downloader) Type() string {
return SourceTypeS3
}
// GCSDownloader downloads resources from Google Cloud Storage.
type GCSDownloader struct{}
// NewGCSDownloader creates a new GCS downloader.
func NewGCSDownloader() *GCSDownloader {
return &GCSDownloader{}
}
// Download fetches a resource from a GCS URL (gs://bucket/key) and writes it to destPath.
func (g *GCSDownloader) Download(ctx context.Context, url string, destPath string) error {
bucket, key, err := parseGCSURL(url)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
client, err := storage.NewClient(ctx)
if err != nil {
return fmt.Errorf("failed to create GCS client: %w", err)
}
defer client.Close()
reader, err := client.Bucket(bucket).Object(key).NewReader(ctx)
if err != nil {
return fmt.Errorf("failed to read object gs://%s/%s: %w", bucket, key, err)
}
defer reader.Close()
f, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
}
defer f.Close()
if _, err := f.ReadFrom(reader); err != nil {
return fmt.Errorf("failed to write object content: %w", err)
}
return nil
}
// Type returns the source type identifier.
func (g *GCSDownloader) Type() string {
return SourceTypeGCS
}
// parseS3URL parses an S3 URL of the form s3://bucket/key.
func parseS3URL(url string) (bucket, key string, err error) {
if !strings.HasPrefix(url, "s3://") {
return "", "", fmt.Errorf("invalid S3 URL, expected s3://bucket/key, got: %s", url)
}
path := strings.TrimPrefix(url, "s3://")
parts := strings.SplitN(path, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("invalid S3 URL, expected s3://bucket/key, got: %s", url)
}
return parts[0], parts[1], nil
}
// parseGCSURL parses a GCS URL of the form gs://bucket/key.
func parseGCSURL(url string) (bucket, key string, err error) {
if !strings.HasPrefix(url, "gs://") {
return "", "", fmt.Errorf("invalid GCS URL, expected gs://bucket/key, got: %s", url)
}
path := strings.TrimPrefix(url, "gs://")
parts := strings.SplitN(path, "/", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("invalid GCS URL, expected gs://bucket/key, got: %s", url)
}
return parts[0], parts[1], nil
}
+128
View File
@@ -0,0 +1,128 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package resource
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseS3URL(t *testing.T) {
tests := []struct {
url string
bucket string
key string
err bool
}{
{"s3://my-bucket/my-key", "my-bucket", "my-key", false},
{"s3://my-bucket/path/to/my-key", "my-bucket", "path/to/my-key", false},
{"s3://my-bucket/", "", "", true},
{"s3://", "", "", true},
{"http://my-bucket/my-key", "", "", true},
{"s3://my-bucket", "", "", true},
}
for _, tt := range tests {
bucket, key, err := parseS3URL(tt.url)
if tt.err {
if err == nil {
t.Errorf("expected error for %s, got nil", tt.url)
}
} else {
if err != nil {
t.Errorf("expected no error for %s, got %v", tt.url, err)
}
if bucket != tt.bucket {
t.Errorf("expected bucket %s, got %s", tt.bucket, bucket)
}
if key != tt.key {
t.Errorf("expected key %s, got %s", tt.key, key)
}
}
}
}
func TestParseGCSURL(t *testing.T) {
tests := []struct {
url string
bucket string
key string
err bool
}{
{"gs://my-bucket/my-key", "my-bucket", "my-key", false},
{"gs://my-bucket/path/to/my-key", "my-bucket", "path/to/my-key", false},
{"gs://my-bucket/", "", "", true},
{"gs://", "", "", true},
{"http://my-bucket/my-key", "", "", true},
{"gs://my-bucket", "", "", true},
}
for _, tt := range tests {
bucket, key, err := parseGCSURL(tt.url)
if tt.err {
if err == nil {
t.Errorf("expected error for %s, got nil", tt.url)
}
} else {
if err != nil {
t.Errorf("expected no error for %s, got %v", tt.url, err)
}
if bucket != tt.bucket {
t.Errorf("expected bucket %s, got %s", tt.bucket, bucket)
}
if key != tt.key {
t.Errorf("expected key %s, got %s", tt.key, key)
}
}
}
}
func TestS3DownloaderErrors(t *testing.T) {
ctx := context.Background()
d := NewS3Downloader("")
assert.Equal(t, SourceTypeS3, d.Type())
t.Run("Invalid URL", func(t *testing.T) {
err := d.Download(ctx, "invalid-url", "dest")
assert.Error(t, err)
})
t.Run("Failed to create directory", func(t *testing.T) {
tmpFile, err := os.CreateTemp("", "blocked")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
err = d.Download(ctx, "s3://bucket/key", filepath.Join(tmpFile.Name(), "subdir", "file"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to create destination directory")
})
}
func TestGCSDownloaderErrors(t *testing.T) {
ctx := context.Background()
d := NewGCSDownloader()
assert.Equal(t, SourceTypeGCS, d.Type())
t.Run("Invalid URL", func(t *testing.T) {
err := d.Download(ctx, "invalid-url", "dest")
assert.Error(t, err)
})
t.Run("Failed to create directory", func(t *testing.T) {
tmpFile, err := os.CreateTemp("", "blocked-gcs")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
err = d.Download(ctx, "gs://bucket/key", filepath.Join(tmpFile.Name(), "subdir", "file"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to create destination directory")
})
}
+19 -4
View File
@@ -45,8 +45,10 @@ var (
// Remote resource configuration.
kbsURL string
algoSourceURL string
algoSourceType string
algoKBSResourcePath string
datasetSourceURLs string
datasetSourceType string
datasetKBSPaths string
algoType string
algoArgsString string
@@ -118,12 +120,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
}
for i := 0; i < len(datasetURLs); i++ {
srcType := datasetSourceType
if srcType == "" {
srcType = "oci-image"
}
datasets = append(datasets, &cvms.Dataset{
Hash: dataHashBytes,
UserKey: pubPem.Bytes,
Filename: fmt.Sprintf("dataset_%d.csv", i),
Source: &cvms.Source{
Type: "oci-image",
Type: srcType,
Url: datasetURLs[i],
KbsResourcePath: datasetKBSPathsList[i],
Encrypted: datasetKBSPathsList[i] != "",
@@ -177,13 +183,20 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
algoArgs = strings.Split(algoArgsString, ",")
}
var algoSrcType string
if algoSourceType != "" {
algoSrcType = algoSourceType
} else {
algoSrcType = "oci-image"
}
algorithm = &cvms.Algorithm{
Hash: algoHashBytes,
UserKey: pubPem.Bytes,
AlgoType: algoType,
AlgoArgs: algoArgs,
Source: &cvms.Source{
Type: "oci-image",
Type: algoSrcType,
Url: algoSourceURL,
KbsResourcePath: algoKBSResourcePath,
Encrypted: algoKBSResourcePath != "",
@@ -259,14 +272,16 @@ func main() {
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
// Remote resource flags
flagSet.StringVar(&kbsURL, "kbs-url", "", "KBS endpoint URL (e.g., 'http://localhost:8080')")
flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (s3://bucket/key or https://...)")
flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (docker://..., s3://..., https://..., etc.)")
flagSet.StringVar(&algoSourceType, "algo-source-type", "", "Algorithm source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')")
flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated")
flagSet.StringVar(&datasetSourceType, "dataset-source-type", "", "Dataset source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated")
flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type (e.g. binary, python)")
flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated")
flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)")
flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type, comma-separated (deprecated, always oci-image)")
flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type (deprecated, use --dataset-source-type)")
flagSet.StringVar(&datasetHash, "dataset-hash", "", "Dataset SHA256 hash (hex string)")
flagSet.StringVar(&datasetDecompress, "dataset-decompress", "", "Dataset decompression bools, comma-separated (e.g. true,false)")