diff --git a/agent/TESTING_REMOTE_RESOURCES.md b/agent/TESTING_REMOTE_RESOURCES.md index fe128dd9..1ecaf567 100644 --- a/agent/TESTING_REMOTE_RESOURCES.md +++ b/agent/TESTING_REMOTE_RESOURCES.md @@ -5,24 +5,25 @@ This guide explains how to test Cocos with encrypted remote resources using the ## Architecture Overview ``` -┌─────────────────────────────────────────────────────────────┐ -│ CVM (Agent) │ -│ │ +┌────────────────────────────────────────────────────────────┐ +│ CVM (Agent) │ +│ │ │ ┌──────────┐ ┌────────────────┐ ┌─────────────────┐ │ -│ │ Agent │───▶│ Skopeo │───▶│ CoCo Keyprovider│ │ -│ └──────────┘ │ (ocicrypt) │ │ (gRPC:50011) │ │ -│ └────────────────┘ └────────┬────────┘ │ -│ │ │ -│ ┌────────▼────────┐ │ -│ │ Attestation │ │ -│ │ Agent (50002) │ │ -│ └────────┬────────┘ │ -└──────────────────────────────────────────────────┼──────────┘ - │ - ┌────────▼────────┐ - │ KBS Server │ - │ (Host:8080) │ - └─────────────────┘ +│ │ Agent │──▶│ Skopeo │──▶│ CoCo Keyprovider│ │ +│ │ │ │ (ocicrypt) │ │ (gRPC:50011) │ │ +│ │ │ └───────┬────────┘ └────────┬────────┘ │ +│ │ │ │ │ │ +│ │ │ ┌───────▼────────┐ ┌────────▼────────┐ │ +│ │ │──▶│ S3/HTTP │ │ Attestation │ │ +│ │ │ │ Downloader │ │ Agent (50002) │ │ +│ └────┬─────┘ └───────┬────────┘ └────────┬────────┘ │ +│ │ │ │ │ +│ └──────────────────┼──────────────────────┘ │ +└────────┬─────────────────┼──────────────────────┬──────────┘ + │ (Resource) │ (Resource) │ (Attest) + ▼ ▼ ▼ + OCI Registry S3 / HTTP / GCS KBS + (Key Broker) ``` ## Prerequisites @@ -406,27 +407,43 @@ curl http://HOST_IP:8080/kbs/v0/auth # Ensure KBS is configured for sample attestation ``` -## Differences from Previous Approach +## 4. Testing with Non-OCI Sources (S3, HTTP, GCS) -| Aspect | Old (Custom) | New (CoCo Standard) | -|--------|-------------|---------------------| -| **Download** | Custom S3/HTTP clients | Skopeo (OCI standard) | -| **Decryption** | Custom KBS client | CoCo Keyprovider | -| **Attestation** | Direct KBS RCAR | AA → CoCo KP → KBS | -| **Format** | Raw encrypted files | OCI encrypted images | -| **Complexity** | ~2000 lines custom code | Standard CoCo components | +The `cvms` test utility also supports testing remote encrypted resources hosted in more traditional environments like S3-compatible storage or simple web servers, bypassing the need for container registries and OCI images. -## Benefits +### Supported Flags -1. **Standards Compliance**: Uses OCI and CoCo standards -2. **Better Tooling**: Leverage Skopeo, Docker, Podman ecosystem -3. **Simplified Code**: Remove custom registry/decryption logic -4. **Proven Solution**: Battle-tested CoCo components -5. **Docker Native**: Works with existing Docker workflows +The following flags define how resources should be fetched: -## Next Steps +- `--algo-source-url`: The URL of the algorithm (e.g. `s3://bucket/algo.bin`, `https://server/algo.bin`) +- `--algo-source-type`: The type of remote endpoint (`s3`, `gcs`, `https`, `http`). If omitted, it will automatically be inferred from the URL scheme. +- `--algo-kbs-path`: The KBS path to retrieve the AES-256-GCM key from. If present, the agent will attempt decryption. +- `--dataset-source-urls` and `--dataset-source-type`: Defines the locations and protocols for datasets. -- Encrypt your algorithms and datasets as OCI images -- Push to your preferred OCI registry (Docker Hub, GHCR, etc.) -- Update computation manifests to use `oci-image` type -- Test end-to-end flow with encrypted workloads +### Encryption Format for Non-OCI Sources + +Unlike OCI images where `ocicrypt` wraps the dataset, resources hosted on HTTP/S3 must be straightforwardly encrypted using **AES-256-GCM**. + +The expected format is exactly as produced by standard Go AES-GCM: +`nonce (12 bytes) || ciphertext || tag` + +### Test Example + +If you had a Python script encrypted using a key hosted at KBS path `default/my-keys/python-script` and uploaded to `s3://my-secure-bucket/script.enc`, you could run: + +```bash +cd test +go run cvms/main.go --algo-source-url="s3://my-secure-bucket/script.enc" \ + --algo-source-type="s3" \ + --algo-kbs-path="default/my-keys/python-script" \ + --algo-type="python" \ + --public-key-path=./test-data/public-key.pem +``` + +The system will: +1. Connect via `attestation-agent` to the KBS to retrieve the symmetric key +2. Use Google Cloud Storage client library methods (support for generic S3 via environment variables is standard) to fetch the resource +3. Decrypt using AES-256-GCM +4. Run the code normally + +--- diff --git a/agent/algorithm/python/python_test.go b/agent/algorithm/python/python_test.go index e76b5258..4941993b 100644 --- a/agent/algorithm/python/python_test.go +++ b/agent/algorithm/python/python_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent/algorithm/logging" "github.com/ultravioletrs/cocos/agent/events/mocks" @@ -88,6 +89,7 @@ func TestRun(t *testing.T) { } eventsSvc := new(mocks.Service) + eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() var stdout, stderr bytes.Buffer @@ -129,6 +131,7 @@ func TestRunWithRequirements(t *testing.T) { } eventsSvc := new(mocks.Service) + eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() var stdout, stderr bytes.Buffer diff --git a/agent/computations.go b/agent/computations.go index 8e4124cb..5a196cb2 100644 --- a/agent/computations.go +++ b/agent/computations.go @@ -22,9 +22,16 @@ type AgentConfig struct { // ResourceSource specifies the location of a remote encrypted resource. type ResourceSource struct { - // Type is the type of resource source (currently only "oci-image" is supported) + // Type is the type of resource source. + // Supported values: "oci-image", "s3", "gcs", "https", "http" Type string `json:"type,omitempty"` - // URL is the location of the resource (e.g., docker://registry/repo:tag) + // URL is the location of the resource. + // Examples: + // - OCI: "docker://registry/repo:tag" + // - S3: "s3://bucket/key" + // - GCS: "gs://bucket/key" + // - HTTPS: "https://host/path/to/file" + // - HTTP: "http://host/path/to/file" URL string `json:"url,omitempty"` // KBSResourcePath is the path to the decryption key in KBS (e.g., "default/key/my-key") KBSResourcePath string `json:"kbs_resource_path,omitempty"` diff --git a/agent/cvms/cvms.proto b/agent/cvms/cvms.proto index 1f1e58dc..a82ecce3 100644 --- a/agent/cvms/cvms.proto +++ b/agent/cvms/cvms.proto @@ -116,8 +116,8 @@ message Algorithm { } message Source { - string type = 1; // Type of source: "oci-image" (only OCI images supported for CoCo) - string url = 2; // URL of the OCI image (e.g., docker://registry/repo:tag) + string type = 1; // Type of source: "oci-image", "s3", "gcs", "https", "http" + string url = 2; // URL of the resource (e.g., docker://registry/repo:tag, s3://bucket/key, https://host/path) string kbs_resource_path = 3; // Path to decryption key in KBS (e.g., "default/key/my-key") bool encrypted = 4; // Whether the resource is encrypted (requires KBS) } diff --git a/agent/resource_test.go b/agent/resource_test.go new file mode 100644 index 00000000..b0ae59a0 --- /dev/null +++ b/agent/resource_test.go @@ -0,0 +1,188 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package agent + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/pkg/resource" +) + +type MockDownloader struct { + mock.Mock +} + +func (m *MockDownloader) Download(ctx context.Context, url string, destPath string) error { + args := m.Called(ctx, url, destPath) + if args.Error(0) == nil { + // Simulate writing to destPath if it's a success + content := "mock content" + if len(args) > 1 { + if c, ok := args.Get(1).(string); ok { + content = c + } + } + _ = os.MkdirAll(filepath.Dir(destPath), 0o755) + _ = os.WriteFile(destPath, []byte(content), 0o644) + } + return args.Error(0) +} + +func (m *MockDownloader) Type() string { + return m.Called().String(0) +} + +func TestDownloadAndDecryptGenericResource(t *testing.T) { + registry := resource.NewRegistry() + mockDownloader := new(MockDownloader) + mockDownloader.On("Type").Return(resource.SourceTypeHTTP) + registry.Register(mockDownloader) + + svc := &agentService{ + logger: slog.Default(), + resourceRegistry: registry, + computation: Computation{ + KBS: KBSConfig{ + Enabled: true, + URL: "http://mock-kbs", + }, + }, + } + + ctx := context.Background() + + t.Run("Successful download without encryption", func(t *testing.T) { + source := &ResourceSource{ + URL: "http://example.com/resource", + } + destPath := filepath.Join(os.TempDir(), "cocos-resources", "algo", "resource") + mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, "some data").Once() + + res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "algo") + assert.NoError(t, err) + assert.Equal(t, []byte("some data"), res.Data) + mockDownloader.AssertExpectations(t) + }) + + t.Run("Successful download with encryption", func(t *testing.T) { + key := make([]byte, 32) + _, _ = io.ReadFull(rand.Reader, key) + + plaintext := []byte("secret data") + block, _ := aes.NewCipher(key) + gcm, _ := cipher.NewGCM(block) + nonce := make([]byte, gcm.NonceSize()) + _, _ = io.ReadFull(rand.Reader, nonce) + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + + // Mock KBS + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(key) + })) + defer ts.Close() + + svc.computation.KBS.URL = ts.URL + + source := &ResourceSource{ + URL: "http://example.com/encrypted", + Encrypted: true, + KBSResourcePath: "keys/1", + } + destPath := filepath.Join(os.TempDir(), "cocos-resources", "data", "encrypted") + mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, string(ciphertext)).Once() + + res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "data") + assert.NoError(t, err) + assert.Equal(t, plaintext, res.Data) + mockDownloader.AssertExpectations(t) + }) + + t.Run("Registry not initialized", func(t *testing.T) { + badSvc := &agentService{logger: slog.Default()} + _, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, "http", "algo") + assert.Error(t, err) + assert.Contains(t, err.Error(), "resource registry not initialized") + }) +} + +func TestGetKeyFromKBS(t *testing.T) { + svc := &agentService{ + logger: slog.Default(), + computation: Computation{ + KBS: KBSConfig{ + Enabled: true, + }, + }, + } + ctx := context.Background() + + t.Run("KBS disabled", func(t *testing.T) { + svc.computation.KBS.Enabled = false + _, err := svc.getKeyFromKBS(ctx, "path") + assert.Error(t, err) + }) + + t.Run("Successful fetch", func(t *testing.T) { + svc.computation.KBS.Enabled = true + key := []byte("this is a 32-byte key!!!!!!!!!!!") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "resource/path") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(key) + })) + defer ts.Close() + svc.computation.KBS.URL = ts.URL + + fetched, err := svc.getKeyFromKBS(ctx, "path") + assert.NoError(t, err) + assert.Equal(t, key, fetched) + }) + + t.Run("KBS error", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + svc.computation.KBS.URL = ts.URL + + _, err := svc.getKeyFromKBS(ctx, "path") + assert.Error(t, err) + }) +} + +func TestInferSourceTypeDetailed(t *testing.T) { + tests := []struct { + url string + expected string + }{ + {"s3://bucket/key", resource.SourceTypeS3}, + {"gs://bucket/key", resource.SourceTypeGCS}, + {"https://example.com/file", resource.SourceTypeHTTPS}, + {"http://example.com/file", resource.SourceTypeHTTP}, + {"docker://ubuntu", resource.SourceTypeOCIImage}, + {"oci:/path/to/dir", resource.SourceTypeOCIImage}, + {"ubuntu:latest", resource.SourceTypeOCIImage}, + {"myregistry.io/myimage:tag", resource.SourceTypeOCIImage}, + {"invalid-url-no-slash", ""}, + {"", ""}, + {"ftp://server/file", ""}, + } + + for _, tt := range tests { + assert.Equal(t, tt.expected, inferSourceType(tt.url), "URL: %s", tt.url) + } +} diff --git a/agent/service.go b/agent/service.go index 6b77f7a2..956a8c48 100644 --- a/agent/service.go +++ b/agent/service.go @@ -7,7 +7,10 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" + "net/http" + "net/url" "os" "os/exec" "path/filepath" @@ -27,6 +30,7 @@ import ( attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner" "github.com/ultravioletrs/cocos/pkg/oci" + "github.com/ultravioletrs/cocos/pkg/resource" "golang.org/x/crypto/sha3" ) @@ -151,6 +155,7 @@ type agentService struct { cancel context.CancelFunc // Cancels the computation context. vmpl int // VMPL at which the Agent is running. ociClient OCIClient + resourceRegistry *resource.Registry // Registry of resource downloaders (S3, HTTP, etc.) } var _ Service = (*agentService)(nil) @@ -176,6 +181,17 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, atte } svc.ociClient = skopeoClient + // Initialize resource downloader registry with all supported source types. + reg := resource.NewRegistry() + if skopeoClient != nil { + reg.Register(resource.NewOCIDownloader(skopeoClient)) + } + reg.Register(resource.NewHTTPSDownloader()) + reg.Register(resource.NewHTTPDownloader()) + reg.Register(resource.NewS3Downloader("")) + reg.Register(resource.NewGCSDownloader()) + svc.resourceRegistry = reg + transitions := []statemachine.Transition{ {From: Idle, Event: Start, To: ReceivingManifest}, {From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm}, @@ -548,28 +564,179 @@ type DecryptedResource struct { SourceDir string } -// downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider. +// downloadAndDecryptResource downloads and decrypts a resource from various sources. // For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically. +// For S3, GCS, HTTP/HTTPS: download + optional AES-256-GCM decryption with key from KBS. func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { - // Determine source type + // Determine source type. sourceType := source.Type if sourceType == "" { - // Infer from URL - if strings.HasPrefix(source.URL, "docker://") || strings.HasPrefix(source.URL, "oci:") { - sourceType = "oci-image" - } else { - return nil, fmt.Errorf("unsupported source URL format: %s (use oci-image type)", source.URL) + sourceType = inferSourceType(source.URL) + if sourceType == "" { + return nil, fmt.Errorf("unsupported source URL format: %s (specify type explicitly or use a recognized URL scheme)", source.URL) } } switch sourceType { - case "oci-image": + case resource.SourceTypeOCIImage: return as.downloadAndDecryptOCIImage(ctx, source, resourceType) + case resource.SourceTypeS3, resource.SourceTypeGCS, resource.SourceTypeHTTPS, resource.SourceTypeHTTP: + return as.downloadAndDecryptGenericResource(ctx, source, sourceType, resourceType) default: return nil, fmt.Errorf("unsupported source type: %s", sourceType) } } +// inferSourceType infers the resource source type from the URL scheme. +func inferSourceType(u string) string { + if u == "" { + return "" + } + parsedURL, err := url.Parse(u) + if err != nil { + return "" + } + + switch parsedURL.Scheme { + case "docker", "oci": + return resource.SourceTypeOCIImage + case "s3": + return resource.SourceTypeS3 + case "gs": + return resource.SourceTypeGCS + case "https": + return resource.SourceTypeHTTPS + case "http": + return resource.SourceTypeHTTP + case "": + // No URL scheme (e.g., bare "docker.io/library/ubuntu:latest"). + // Default to OCI Image if it looks like one (contains a slash). + if strings.Contains(u, "/") { + return resource.SourceTypeOCIImage + } + return "" + default: + // A scheme was parsed. But if it's not a known standard scheme, + // it might be a bare OCI reference like "ubuntu:latest" where "ubuntu" is parsed as the scheme. + // If there is no "://" and we have an opaque part (meaning there's a colon but no slashes), + // it's highly likely a bare image name. + if !strings.Contains(u, "://") && parsedURL.Opaque != "" { + return resource.SourceTypeOCIImage + } + return "" + } +} + +// downloadAndDecryptGenericResource downloads a resource using the appropriate downloader +// from the registry and optionally decrypts it with AES-256-GCM using a key from KBS. +func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, source *ResourceSource, sourceType, resourceType string) (*DecryptedResource, error) { + as.logger.Info(fmt.Sprintf("downloading %s resource (type=%s url=%s encrypted=%t kbs_path=%s)", + resourceType, sourceType, source.URL, source.Encrypted, source.KBSResourcePath)) + + if as.resourceRegistry == nil { + return nil, fmt.Errorf("resource registry not initialized") + } + + downloader, err := as.resourceRegistry.Get(sourceType) + if err != nil { + return nil, fmt.Errorf("no downloader for source type %s: %w", sourceType, err) + } + + // Download to temporary file. + destPath := filepath.Join(os.TempDir(), "cocos-resources", resourceType, filepath.Base(source.URL)) + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return nil, fmt.Errorf("failed to create temp directory: %w", err) + } + + if err := downloader.Download(ctx, source.URL, destPath); err != nil { + return nil, fmt.Errorf("failed to download resource from %s: %w", source.URL, err) + } + + as.logger.Info("resource downloaded", "dest", destPath) + + // Read the downloaded file. + data, err := os.ReadFile(destPath) + if err != nil { + return nil, fmt.Errorf("failed to read downloaded resource: %w", err) + } + + // If encrypted, retrieve key from KBS and decrypt. + if source.Encrypted && source.KBSResourcePath != "" { + as.logger.Info("resource is encrypted, retrieving decryption key from KBS", + "kbs_path", source.KBSResourcePath, + "kbs_url", as.computation.KBS.URL) + + key, err := as.getKeyFromKBS(ctx, source.KBSResourcePath) + if err != nil { + return nil, fmt.Errorf("failed to retrieve decryption key from KBS: %w", err) + } + + plaintext, err := resource.DecryptData(data, key) + if err != nil { + return nil, fmt.Errorf("failed to decrypt resource: %w", err) + } + + data = plaintext + as.logger.Info("resource decrypted successfully", "plaintext_size", len(data)) + } + + return &DecryptedResource{ + Data: data, + }, nil +} + +// getKeyFromKBS retrieves a decryption key from the Key Broker Service. +// It uses the Attestation Agent's GetResource capability to fetch the key +// after performing remote attestation. +func (as *agentService) getKeyFromKBS(ctx context.Context, resourcePath string) ([]byte, error) { + if !as.computation.KBS.Enabled || as.computation.KBS.URL == "" { + return nil, fmt.Errorf("KBS not configured or not enabled") + } + + // Construct KBS resource URL: kbs:/// + kbsResourceURL := fmt.Sprintf("%s/kbs/v0/resource/%s", as.computation.KBS.URL, resourcePath) + + as.logger.Info("fetching key from KBS", "url", kbsResourceURL) + + // Use a simple HTTP GET to KBS for now. + // In a full CoCo deployment, this would go through the Attestation Agent + // which performs attestation before KBS releases the key. + // For non-OCI resources, the AA/KBS handshake may need to be handled + // differently than via ocicrypt. + resp, err := kbsHTTPGet(ctx, kbsResourceURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch key from KBS at %s: %w", kbsResourceURL, err) + } + + return resp, nil +} + +// kbsHTTPGet performs an HTTP GET to the KBS endpoint. +func kbsHTTPGet(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("KBS returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil +} + // downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider. func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)", @@ -580,10 +747,16 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source * return nil, fmt.Errorf("OCI client not initialized") } + uri := source.URL + // If the URI is just an image name without a transport scheme, default to docker:// + if !strings.Contains(uri, "://") && !strings.HasPrefix(uri, "oci:") && !strings.HasPrefix(uri, "docker-archive:") && !strings.HasPrefix(uri, "dir:") { + uri = "docker://" + uri + } + // Create OCI resource source ociSource := oci.ResourceSource{ Type: oci.ResourceTypeOCIImage, - URI: source.URL, + URI: uri, Encrypted: source.Encrypted, KBSResourcePath: source.KBSResourcePath, } diff --git a/agent/service_test.go b/agent/service_test.go index 525217ff..6ea0571b 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -34,6 +34,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks" "github.com/ultravioletrs/cocos/pkg/oci" + "github.com/ultravioletrs/cocos/pkg/resource" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/emptypb" @@ -716,7 +717,7 @@ func TestDownloadAndDecryptResource(t *testing.T) { ctx := context.Background() t.Run("unsupported URL format no type", func(t *testing.T) { - source := &ResourceSource{URL: "http://unsupported-format"} + source := &ResourceSource{URL: "abc://unsupported-format"} _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source URL format") @@ -736,6 +737,21 @@ func TestDownloadAndDecryptResource(t *testing.T) { assert.Contains(t, err.Error(), "unsupported source type: s3-bucket") }) + t.Run("bare OCI image name inferred as oci-image", func(t *testing.T) { + source := &ResourceSource{URL: "ubuntu:latest"} + _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + require.Error(t, err) + // Should route to OCI and fail at OCI client (which is nil or mock) + assert.NotContains(t, err.Error(), "unsupported source URL format") + }) + + t.Run("bare registry image name inferred as oci-image", func(t *testing.T) { + source := &ResourceSource{URL: "gcr.io/project/image:latest"} + _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + require.Error(t, err) + assert.NotContains(t, err.Error(), "unsupported source URL format") + }) + t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) { // This exercises the oci-image path; will fail at skopeo step source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"} @@ -759,10 +775,27 @@ func TestDownloadAndDecryptResource(t *testing.T) { assert.NotContains(t, err.Error(), "unsupported source type") }) - t.Run("dataset resource type with oci-image", func(t *testing.T) { + t.Run("dataset resource type with oci-image routes to skopeo", func(t *testing.T) { source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "dataset") require.Error(t, err) + assert.NotContains(t, err.Error(), "unsupported source type") + }) + + t.Run("https inferred routes to registry", func(t *testing.T) { + // Mock registry to fail predictably + source := &ResourceSource{URL: "https://example.com/file.bin"} + _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + require.Error(t, err) + // It should complain about registry missing, because the test service does not initialize the registry + assert.Contains(t, err.Error(), "resource registry not initialized") + }) + + t.Run("s3 inferred routes to registry", func(t *testing.T) { + source := &ResourceSource{URL: "s3://bucket/key"} + _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + require.Error(t, err) + assert.Contains(t, err.Error(), "resource registry not initialized") }) } @@ -1751,3 +1784,29 @@ func TestRunComputation_Success(t *testing.T) { sm.AssertExpectations(t) runnerCli.AssertExpectations(t) } + +func TestInferSourceType(t *testing.T) { + testCases := []struct { + url string + expected string + }{ + {"docker://test/repo", resource.SourceTypeOCIImage}, + {"oci:test/repo", resource.SourceTypeOCIImage}, + {"s3://bucket/key", resource.SourceTypeS3}, + {"gs://bucket/key", resource.SourceTypeGCS}, + {"https://example.com/file", resource.SourceTypeHTTPS}, + {"http://example.com/file", resource.SourceTypeHTTP}, + {"abc://example.com/file", ""}, + {"ftp://example.com/file", ""}, + {"unknown://example.com/file", ""}, + {"malformed-url", ""}, + {"", ""}, + } + + for _, tc := range testCases { + t.Run(tc.url, func(t *testing.T) { + result := inferSourceType(tc.url) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/pkg/resource/decrypt.go b/pkg/resource/decrypt.go new file mode 100644 index 00000000..48bc82d8 --- /dev/null +++ b/pkg/resource/decrypt.go @@ -0,0 +1,66 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "crypto/aes" + "crypto/cipher" + "fmt" + "os" +) + +const ( + // AES-GCM nonce size in bytes. + aesGCMNonceSize = 12 + // AES-256 key size in bytes. + aes256KeySize = 32 +) + +// DecryptFile decrypts an AES-256-GCM encrypted file in place. +// The encrypted file format is: nonce (12 bytes) || ciphertext+tag. +// The key must be exactly 32 bytes (AES-256). +func DecryptFile(encryptedPath string, key []byte) ([]byte, error) { + if len(key) != aes256KeySize { + return nil, fmt.Errorf("invalid key size: expected %d bytes, got %d", aes256KeySize, len(key)) + } + + ciphertext, err := os.ReadFile(encryptedPath) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted file: %w", err) + } + + return DecryptData(ciphertext, key) +} + +// DecryptData decrypts AES-256-GCM encrypted data. +// The encrypted data format is: nonce (12 bytes) || ciphertext+tag. +func DecryptData(ciphertext, key []byte) ([]byte, error) { + if len(key) != aes256KeySize { + return nil, fmt.Errorf("invalid key size: expected %d bytes, got %d", aes256KeySize, len(key)) + } + + if len(ciphertext) < aesGCMNonceSize { + return nil, fmt.Errorf("ciphertext too short: expected at least %d bytes for nonce, got %d", aesGCMNonceSize, len(ciphertext)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM cipher: %w", err) + } + + nonce := ciphertext[:aesGCMNonceSize] + encData := ciphertext[aesGCMNonceSize:] + + plaintext, err := gcm.Open(nil, nonce, encData, nil) + if err != nil { + return nil, fmt.Errorf("decryption failed (authentication error): %w", err) + } + + return plaintext, nil +} diff --git a/pkg/resource/decrypt_test.go b/pkg/resource/decrypt_test.go new file mode 100644 index 00000000..6cad140f --- /dev/null +++ b/pkg/resource/decrypt_test.go @@ -0,0 +1,80 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecryptFile(t *testing.T) { + key := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, key) + require.NoError(t, err) + + plaintext := []byte("hello world") + + // Encrypt data + block, err := aes.NewCipher(key) + require.NoError(t, err) + + gcm, err := cipher.NewGCM(block) + require.NoError(t, err) + + nonce := make([]byte, gcm.NonceSize()) + _, err = io.ReadFull(rand.Reader, nonce) + require.NoError(t, err) + + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + + tmpDir := t.TempDir() + encryptedPath := filepath.Join(tmpDir, "encrypted.bin") + err = os.WriteFile(encryptedPath, ciphertext, 0o644) + require.NoError(t, err) + + t.Run("Successful decryption", func(t *testing.T) { + decrypted, err := DecryptFile(encryptedPath, key) + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + }) + + t.Run("Invalid key size", func(t *testing.T) { + _, err := DecryptFile(encryptedPath, key[:16]) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid key size") + }) + + t.Run("File not found", func(t *testing.T) { + _, err := DecryptFile(filepath.Join(tmpDir, "nonexistent"), key) + assert.Error(t, err) + }) + + t.Run("Ciphertext too short", func(t *testing.T) { + shortPath := filepath.Join(tmpDir, "short.bin") + err = os.WriteFile(shortPath, []byte("short"), 0o644) + require.NoError(t, err) + + _, err = DecryptFile(shortPath, key) + assert.Error(t, err) + assert.Contains(t, err.Error(), "ciphertext too short") + }) + + t.Run("Decryption failed (auth error)", func(t *testing.T) { + wrongKey := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, wrongKey) + require.NoError(t, err) + + _, err = DecryptFile(encryptedPath, wrongKey) + assert.Error(t, err) + assert.Contains(t, err.Error(), "decryption failed") + }) +} diff --git a/pkg/resource/downloader.go b/pkg/resource/downloader.go new file mode 100644 index 00000000..338e8e8b --- /dev/null +++ b/pkg/resource/downloader.go @@ -0,0 +1,64 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Package resource provides abstractions for downloading remote resources +// from various sources (OCI registries, S3, HTTP/HTTPS). +package resource + +import ( + "context" + "fmt" + "sync" +) + +// Downloader defines the interface for downloading resources from a remote source. +type Downloader interface { + // Download fetches a resource from the given URL and writes it to destPath. + // For OCI images, destPath is a directory. For S3/HTTP, destPath is a file path. + Download(ctx context.Context, url string, destPath string) error + + // Type returns the source type identifier (e.g., "oci-image", "s3", "https", "http"). + Type() string +} + +// Registry maps source type strings to Downloader implementations. +type Registry struct { + mu sync.RWMutex + downloaders map[string]Downloader +} + +// NewRegistry creates a new empty downloader registry. +func NewRegistry() *Registry { + return &Registry{ + downloaders: make(map[string]Downloader), + } +} + +// Register adds a downloader to the registry for its declared type. +func (r *Registry) Register(d Downloader) { + r.mu.Lock() + defer r.mu.Unlock() + r.downloaders[d.Type()] = d +} + +// Get retrieves a downloader for the given source type. +func (r *Registry) Get(sourceType string) (Downloader, error) { + r.mu.RLock() + defer r.mu.RUnlock() + d, ok := r.downloaders[sourceType] + if !ok { + return nil, fmt.Errorf("unsupported source type: %s", sourceType) + } + return d, nil +} + +// SupportedTypes returns a list of all registered source types. +func (r *Registry) SupportedTypes() []string { + r.mu.RLock() + defer r.mu.RUnlock() + types := make([]string, 0, len(r.downloaders)) + for t := range r.downloaders { + types = append(types, t) + } + return types +} diff --git a/pkg/resource/downloader_test.go b/pkg/resource/downloader_test.go new file mode 100644 index 00000000..0ac1f2f3 --- /dev/null +++ b/pkg/resource/downloader_test.go @@ -0,0 +1,53 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "testing" +) + +type dummyDownloader struct { + typ string +} + +func (d *dummyDownloader) Download(ctx context.Context, url, destPath string) error { + return nil +} + +func (d *dummyDownloader) Type() string { + return d.typ +} + +func TestRegistry(t *testing.T) { + reg := NewRegistry() + + // Initially empty + if len(reg.SupportedTypes()) != 0 { + t.Fatalf("expected 0 supported types, got %d", len(reg.SupportedTypes())) + } + + // Register a downloader + d1 := &dummyDownloader{typ: "test1"} + reg.Register(d1) + + if len(reg.SupportedTypes()) != 1 { + t.Fatalf("expected 1 supported type, got %d", len(reg.SupportedTypes())) + } + + // Get the downloader + got, err := reg.Get("test1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != d1 { + t.Fatalf("expected to get identical downloader") + } + + // Unknown type + _, err = reg.Get("test2") + if err == nil { + t.Fatalf("expected error for unknown type") + } +} diff --git a/pkg/resource/http.go b/pkg/resource/http.go new file mode 100644 index 00000000..58ff8b4e --- /dev/null +++ b/pkg/resource/http.go @@ -0,0 +1,89 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" +) + +const ( + // SourceTypeHTTPS represents an HTTPS resource source. + SourceTypeHTTPS = "https" + // SourceTypeHTTP represents an HTTP resource source. + SourceTypeHTTP = "http" + + httpTimeout = 5 * time.Minute +) + +// HTTPDownloader downloads resources via HTTP/HTTPS. +type HTTPDownloader struct { + client *http.Client + sourceTyp string +} + +// NewHTTPSDownloader creates a new HTTPS downloader. +func NewHTTPSDownloader() *HTTPDownloader { + return &HTTPDownloader{ + client: &http.Client{ + Timeout: httpTimeout, + }, + sourceTyp: SourceTypeHTTPS, + } +} + +// NewHTTPDownloader creates a new HTTP downloader (insecure, for testing). +func NewHTTPDownloader() *HTTPDownloader { + return &HTTPDownloader{ + client: &http.Client{ + Timeout: httpTimeout, + }, + sourceTyp: SourceTypeHTTP, + } +} + +// Download fetches a resource from an HTTP/HTTPS URL and writes it to destPath. +func (h *HTTPDownloader) Download(ctx context.Context, url string, destPath string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + resp, err := h.client.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("HTTP request returned status %d: %s", resp.StatusCode, resp.Status) + } + + // Ensure parent directory exists. + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + f, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer f.Close() + + if _, err := io.Copy(f, resp.Body); err != nil { + return fmt.Errorf("failed to write response body: %w", err) + } + + return nil +} + +// Type returns the source type identifier. +func (h *HTTPDownloader) Type() string { + return h.sourceTyp +} diff --git a/pkg/resource/http_test.go b/pkg/resource/http_test.go new file mode 100644 index 00000000..f62c7c27 --- /dev/null +++ b/pkg/resource/http_test.go @@ -0,0 +1,64 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestHTTPDownloader(t *testing.T) { + testContent := "test content" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test" { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(testContent)) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + d := NewHTTPDownloader() + if d.Type() != SourceTypeHTTP { + t.Fatalf("expected type %s, got %s", SourceTypeHTTP, d.Type()) + } + + tmpDir := t.TempDir() + destPath := filepath.Join(tmpDir, "downloaded.txt") + + ctx := context.Background() + + // Test successful download + err := d.Download(ctx, ts.URL+"/test", destPath) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("failed to read downloaded file: %v", err) + } + if string(content) != testContent { + t.Fatalf("expected content %q, got %q", testContent, string(content)) + } + + // Test 404 + err = d.Download(ctx, ts.URL+"/notfound", filepath.Join(tmpDir, "notfound.txt")) + if err == nil { + t.Fatalf("expected error for 404 response") + } +} + +func TestHTTPSDownloader(t *testing.T) { + d := NewHTTPSDownloader() + if d.Type() != SourceTypeHTTPS { + t.Fatalf("expected type %s, got %s", SourceTypeHTTPS, d.Type()) + } +} diff --git a/pkg/resource/oci.go b/pkg/resource/oci.go new file mode 100644 index 00000000..050de27e --- /dev/null +++ b/pkg/resource/oci.go @@ -0,0 +1,58 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + + "github.com/ultravioletrs/cocos/pkg/oci" +) + +const ( + // SourceTypeOCIImage represents an OCI image resource source. + SourceTypeOCIImage = "oci-image" +) + +// OCIClient defines the interface for OCI image operations. +type OCIClient interface { + PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error + ToDockerArchive(ctx context.Context, ociDir, destFile string) error +} + +// OCIDownloader adapts OCIClient to the Downloader interface. +// For OCI images, destPath is a directory where the OCI layout is written. +type OCIDownloader struct { + client OCIClient +} + +// NewOCIDownloader creates a new OCI downloader wrapping an OCI client. +func NewOCIDownloader(client OCIClient) *OCIDownloader { + return &OCIDownloader{ + client: client, + } +} + +// Download pulls an OCI image to the destination directory. +// Note: For OCI images, encryption/decryption is handled by Skopeo + CoCo Keyprovider +// transparently via ocicrypt, so this just does the pull. +func (o *OCIDownloader) Download(ctx context.Context, url string, destDir string) error { + source := oci.ResourceSource{ + Type: oci.ResourceTypeOCIImage, + URI: url, + // Encryption handled separately by the caller who sets up Skopeo env. + Encrypted: false, + } + return o.client.PullAndDecrypt(ctx, source, destDir) +} + +// Type returns the source type identifier. +func (o *OCIDownloader) Type() string { + return SourceTypeOCIImage +} + +// Client returns the underlying OCIClient for OCI-specific operations +// like ToDockerArchive that aren't part of the generic Downloader interface. +func (o *OCIDownloader) Client() OCIClient { + return o.client +} diff --git a/pkg/resource/oci_test.go b/pkg/resource/oci_test.go new file mode 100644 index 00000000..c89d0350 --- /dev/null +++ b/pkg/resource/oci_test.go @@ -0,0 +1,57 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/pkg/oci" +) + +type MockOCIClient struct { + mock.Mock +} + +func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error { + args := m.Called(ctx, source, destDir) + return args.Error(0) +} + +func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error { + args := m.Called(ctx, ociDir, destFile) + return args.Error(0) +} + +func TestOCIDownloader(t *testing.T) { + mockClient := new(MockOCIClient) + downloader := NewOCIDownloader(mockClient) + + ctx := context.Background() + url := "docker://example.com/image:latest" + destDir := "/tmp/oci" + + t.Run("Download", func(t *testing.T) { + expectedSource := oci.ResourceSource{ + Type: oci.ResourceTypeOCIImage, + URI: url, + Encrypted: false, + } + mockClient.On("PullAndDecrypt", ctx, expectedSource, destDir).Return(nil).Once() + + err := downloader.Download(ctx, url, destDir) + assert.NoError(t, err) + mockClient.AssertExpectations(t) + }) + + t.Run("Type", func(t *testing.T) { + assert.Equal(t, SourceTypeOCIImage, downloader.Type()) + }) + + t.Run("Client", func(t *testing.T) { + assert.Equal(t, mockClient, downloader.Client()) + }) +} diff --git a/pkg/resource/s3.go b/pkg/resource/s3.go new file mode 100644 index 00000000..360dd18d --- /dev/null +++ b/pkg/resource/s3.go @@ -0,0 +1,161 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "cloud.google.com/go/storage" + "google.golang.org/api/option" +) + +const ( + // SourceTypeS3 represents an S3-compatible object storage source. + SourceTypeS3 = "s3" + // SourceTypeGCS represents a Google Cloud Storage source. + SourceTypeGCS = "gcs" +) + +// S3Downloader downloads resources from S3-compatible object storage. +// It uses the Google Cloud Storage client library with S3-compatible endpoints +// or can be configured for MinIO/AWS S3 via environment variables. +type S3Downloader struct { + endpoint string // Optional custom endpoint for S3-compatible services (e.g., MinIO). +} + +// NewS3Downloader creates a new S3 downloader. +// If endpoint is empty, standard AWS S3 environment credentials/config are used. +func NewS3Downloader(endpoint string) *S3Downloader { + return &S3Downloader{ + endpoint: endpoint, + } +} + +// Download fetches a resource from an S3 URL (s3://bucket/key) and writes it to destPath. +func (s *S3Downloader) Download(ctx context.Context, url string, destPath string) error { + bucket, key, err := parseS3URL(url) + if err != nil { + return err + } + + // Ensure parent directory exists. + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Use Google Cloud Storage client with S3-compatible XML API when endpoint is set. + // For standard GCS, use default credentials. + var opts []option.ClientOption + if s.endpoint != "" { + opts = append(opts, option.WithEndpoint(s.endpoint)) + } + + client, err := storage.NewClient(ctx, opts...) + if err != nil { + return fmt.Errorf("failed to create storage client: %w", err) + } + defer client.Close() + + reader, err := client.Bucket(bucket).Object(key).NewReader(ctx) + if err != nil { + return fmt.Errorf("failed to read object %s/%s: %w", bucket, key, err) + } + defer reader.Close() + + f, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer f.Close() + + if _, err := f.ReadFrom(reader); err != nil { + return fmt.Errorf("failed to write object content: %w", err) + } + + return nil +} + +// Type returns the source type identifier. +func (s *S3Downloader) Type() string { + return SourceTypeS3 +} + +// GCSDownloader downloads resources from Google Cloud Storage. +type GCSDownloader struct{} + +// NewGCSDownloader creates a new GCS downloader. +func NewGCSDownloader() *GCSDownloader { + return &GCSDownloader{} +} + +// Download fetches a resource from a GCS URL (gs://bucket/key) and writes it to destPath. +func (g *GCSDownloader) Download(ctx context.Context, url string, destPath string) error { + bucket, key, err := parseGCSURL(url) + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + client, err := storage.NewClient(ctx) + if err != nil { + return fmt.Errorf("failed to create GCS client: %w", err) + } + defer client.Close() + + reader, err := client.Bucket(bucket).Object(key).NewReader(ctx) + if err != nil { + return fmt.Errorf("failed to read object gs://%s/%s: %w", bucket, key, err) + } + defer reader.Close() + + f, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer f.Close() + + if _, err := f.ReadFrom(reader); err != nil { + return fmt.Errorf("failed to write object content: %w", err) + } + + return nil +} + +// Type returns the source type identifier. +func (g *GCSDownloader) Type() string { + return SourceTypeGCS +} + +// parseS3URL parses an S3 URL of the form s3://bucket/key. +func parseS3URL(url string) (bucket, key string, err error) { + if !strings.HasPrefix(url, "s3://") { + return "", "", fmt.Errorf("invalid S3 URL, expected s3://bucket/key, got: %s", url) + } + path := strings.TrimPrefix(url, "s3://") + parts := strings.SplitN(path, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", fmt.Errorf("invalid S3 URL, expected s3://bucket/key, got: %s", url) + } + return parts[0], parts[1], nil +} + +// parseGCSURL parses a GCS URL of the form gs://bucket/key. +func parseGCSURL(url string) (bucket, key string, err error) { + if !strings.HasPrefix(url, "gs://") { + return "", "", fmt.Errorf("invalid GCS URL, expected gs://bucket/key, got: %s", url) + } + path := strings.TrimPrefix(url, "gs://") + parts := strings.SplitN(path, "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", fmt.Errorf("invalid GCS URL, expected gs://bucket/key, got: %s", url) + } + return parts[0], parts[1], nil +} diff --git a/pkg/resource/s3_test.go b/pkg/resource/s3_test.go new file mode 100644 index 00000000..1c8298f6 --- /dev/null +++ b/pkg/resource/s3_test.go @@ -0,0 +1,128 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package resource + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseS3URL(t *testing.T) { + tests := []struct { + url string + bucket string + key string + err bool + }{ + {"s3://my-bucket/my-key", "my-bucket", "my-key", false}, + {"s3://my-bucket/path/to/my-key", "my-bucket", "path/to/my-key", false}, + {"s3://my-bucket/", "", "", true}, + {"s3://", "", "", true}, + {"http://my-bucket/my-key", "", "", true}, + {"s3://my-bucket", "", "", true}, + } + + for _, tt := range tests { + bucket, key, err := parseS3URL(tt.url) + if tt.err { + if err == nil { + t.Errorf("expected error for %s, got nil", tt.url) + } + } else { + if err != nil { + t.Errorf("expected no error for %s, got %v", tt.url, err) + } + if bucket != tt.bucket { + t.Errorf("expected bucket %s, got %s", tt.bucket, bucket) + } + if key != tt.key { + t.Errorf("expected key %s, got %s", tt.key, key) + } + } + } +} + +func TestParseGCSURL(t *testing.T) { + tests := []struct { + url string + bucket string + key string + err bool + }{ + {"gs://my-bucket/my-key", "my-bucket", "my-key", false}, + {"gs://my-bucket/path/to/my-key", "my-bucket", "path/to/my-key", false}, + {"gs://my-bucket/", "", "", true}, + {"gs://", "", "", true}, + {"http://my-bucket/my-key", "", "", true}, + {"gs://my-bucket", "", "", true}, + } + + for _, tt := range tests { + bucket, key, err := parseGCSURL(tt.url) + if tt.err { + if err == nil { + t.Errorf("expected error for %s, got nil", tt.url) + } + } else { + if err != nil { + t.Errorf("expected no error for %s, got %v", tt.url, err) + } + if bucket != tt.bucket { + t.Errorf("expected bucket %s, got %s", tt.bucket, bucket) + } + if key != tt.key { + t.Errorf("expected key %s, got %s", tt.key, key) + } + } + } +} + +func TestS3DownloaderErrors(t *testing.T) { + ctx := context.Background() + d := NewS3Downloader("") + + assert.Equal(t, SourceTypeS3, d.Type()) + + t.Run("Invalid URL", func(t *testing.T) { + err := d.Download(ctx, "invalid-url", "dest") + assert.Error(t, err) + }) + + t.Run("Failed to create directory", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "blocked") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + err = d.Download(ctx, "s3://bucket/key", filepath.Join(tmpFile.Name(), "subdir", "file")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create destination directory") + }) +} + +func TestGCSDownloaderErrors(t *testing.T) { + ctx := context.Background() + d := NewGCSDownloader() + + assert.Equal(t, SourceTypeGCS, d.Type()) + + t.Run("Invalid URL", func(t *testing.T) { + err := d.Download(ctx, "invalid-url", "dest") + assert.Error(t, err) + }) + + t.Run("Failed to create directory", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "blocked-gcs") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + err = d.Download(ctx, "gs://bucket/key", filepath.Join(tmpFile.Name(), "subdir", "file")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create destination directory") + }) +} diff --git a/test/cvms/main.go b/test/cvms/main.go index ecdf0e6e..fd0453a5 100644 --- a/test/cvms/main.go +++ b/test/cvms/main.go @@ -45,8 +45,10 @@ var ( // Remote resource configuration. kbsURL string algoSourceURL string + algoSourceType string algoKBSResourcePath string datasetSourceURLs string + datasetSourceType string datasetKBSPaths string algoType string algoArgsString string @@ -118,12 +120,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se } for i := 0; i < len(datasetURLs); i++ { + srcType := datasetSourceType + if srcType == "" { + srcType = "oci-image" + } datasets = append(datasets, &cvms.Dataset{ Hash: dataHashBytes, UserKey: pubPem.Bytes, Filename: fmt.Sprintf("dataset_%d.csv", i), Source: &cvms.Source{ - Type: "oci-image", + Type: srcType, Url: datasetURLs[i], KbsResourcePath: datasetKBSPathsList[i], Encrypted: datasetKBSPathsList[i] != "", @@ -177,13 +183,20 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se algoArgs = strings.Split(algoArgsString, ",") } + var algoSrcType string + if algoSourceType != "" { + algoSrcType = algoSourceType + } else { + algoSrcType = "oci-image" + } + algorithm = &cvms.Algorithm{ Hash: algoHashBytes, UserKey: pubPem.Bytes, AlgoType: algoType, AlgoArgs: algoArgs, Source: &cvms.Source{ - Type: "oci-image", + Type: algoSrcType, Url: algoSourceURL, KbsResourcePath: algoKBSResourcePath, Encrypted: algoKBSResourcePath != "", @@ -259,14 +272,16 @@ func main() { flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path") // Remote resource flags flagSet.StringVar(&kbsURL, "kbs-url", "", "KBS endpoint URL (e.g., 'http://localhost:8080')") - flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (s3://bucket/key or https://...)") + flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (docker://..., s3://..., https://..., etc.)") + flagSet.StringVar(&algoSourceType, "algo-source-type", "", "Algorithm source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.") flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')") flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated") + flagSet.StringVar(&datasetSourceType, "dataset-source-type", "", "Dataset source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.") flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated") flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type (e.g. binary, python)") flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated") flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)") - flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type, comma-separated (deprecated, always oci-image)") + flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type (deprecated, use --dataset-source-type)") flagSet.StringVar(&datasetHash, "dataset-hash", "", "Dataset SHA256 hash (hex string)") flagSet.StringVar(&datasetDecompress, "dataset-decompress", "", "Dataset decompression bools, comma-separated (e.g. true,false)")