Files
cocos/agent/resource_test.go
T
Sammy Kerata Oina c59a413765
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
NOISSUE - Implement extensible resource downloader framework with support for S3, GCS, and OCI sources (#590)
* feat: implement extensible resource downloader framework with support for S3, GCS, and OCI sources

Signed-off-by: SammyOina <sammyoina@gmail.com>

* refactor: improve resource URL parsing and add support for bare OCI image references

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: add empty string check and slash requirement for OCI image inference, and update python unit tests with event mock expectations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: introduce OCIClient interface, add test coverage for decryption, and improve resource download error handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: remove trailing whitespace in OCI downloader and HTTP tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: SammyOina <sammyoina@gmail.com>
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-04-28 11:21:03 +02:00

189 lines
5.1 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/resource"
)
// MockDownloader is a testify-backed stand-in for a resource downloader.
// Tests program it with On(...).Return(...) and verify the recorded calls
// with AssertExpectations.
type MockDownloader struct {
	// Mock records calls and replays the configured return values.
	mock.Mock
}
// Download records the call on the mock. When the configured first return
// value is a nil error, it also writes a payload to destPath so the code
// under test has a file to read back.
//
// The payload defaults to "mock content". An optional second Return value of
// type string overrides it, e.g. .Return(nil, "some data").
//
// Failures to create the parent directory or write the file are returned
// rather than discarded. A broken fixture therefore surfaces as a download
// error instead of a confusing downstream read/decrypt failure.
func (m *MockDownloader) Download(ctx context.Context, url string, destPath string) error {
	args := m.Called(ctx, url, destPath)
	if err := args.Error(0); err != nil {
		return err
	}

	content := "mock content"
	if len(args) > 1 {
		if c, ok := args.Get(1).(string); ok {
			content = c
		}
	}

	if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
		return err
	}
	return os.WriteFile(destPath, []byte(content), 0o644)
}
// Type returns the source type string the test configured with
// .On("Type").Return(...). The registry uses it to register this mock.
func (m *MockDownloader) Type() string {
	ret := m.Called()
	return ret.String(0)
}
// TestDownloadAndDecryptGenericResource exercises downloadAndDecryptGenericResource
// against a registry holding a MockDownloader registered as the HTTP source
// type. The mock writes its configured payload to the destination path the
// service requests. A local httptest server stands in for the KBS when a
// decryption key is needed.
func TestDownloadAndDecryptGenericResource(t *testing.T) {
	registry := resource.NewRegistry()
	mockDownloader := new(MockDownloader)
	mockDownloader.On("Type").Return(resource.SourceTypeHTTP)
	registry.Register(mockDownloader)

	svc := &agentService{
		logger:           slog.Default(),
		resourceRegistry: registry,
		computation: Computation{
			KBS: KBSConfig{
				Enabled: true,
				URL:     "http://mock-kbs",
			},
		},
	}
	ctx := context.Background()

	t.Run("Successful download without encryption", func(t *testing.T) {
		source := &ResourceSource{
			URL: "http://example.com/resource",
		}
		// destPath must match the path the service hands to the downloader,
		// otherwise the mock expectation will not match.
		destPath := filepath.Join(os.TempDir(), "cocos-resources", "algo", "resource")
		// Do not leave fixture files behind in the shared temp directory.
		t.Cleanup(func() { _ = os.Remove(destPath) })
		mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, "some data").Once()

		res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "algo")
		// On error res may be unusable (possibly nil), so stop before reading it.
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, []byte("some data"), res.Data)
		mockDownloader.AssertExpectations(t)
	})

	t.Run("Successful download with encryption", func(t *testing.T) {
		plaintext := []byte("secret data")

		key := make([]byte, 32)
		if _, err := io.ReadFull(rand.Reader, key); err != nil {
			t.Fatalf("generating key: %v", err)
		}
		block, err := aes.NewCipher(key)
		if err != nil {
			t.Fatalf("creating AES cipher: %v", err)
		}
		gcm, err := cipher.NewGCM(block)
		if err != nil {
			t.Fatalf("creating GCM: %v", err)
		}
		nonce := make([]byte, gcm.NonceSize())
		if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
			t.Fatalf("generating nonce: %v", err)
		}
		// Seal appends to nonce, so the payload is nonce || ciphertext || tag.
		// The test relies on the service accepting this layout.
		ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)

		// Mock KBS: serves the AES key, and checks that the source's
		// KBSResourcePath is what gets requested.
		kbs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			assert.Contains(t, r.URL.Path, "keys/1")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(key)
		}))
		defer kbs.Close()
		svc.computation.KBS.URL = kbs.URL

		source := &ResourceSource{
			URL:             "http://example.com/encrypted",
			Encrypted:       true,
			KBSResourcePath: "keys/1",
		}
		destPath := filepath.Join(os.TempDir(), "cocos-resources", "data", "encrypted")
		t.Cleanup(func() { _ = os.Remove(destPath) })
		mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, string(ciphertext)).Once()

		res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "data")
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, plaintext, res.Data)
		mockDownloader.AssertExpectations(t)
	})

	t.Run("Registry not initialized", func(t *testing.T) {
		badSvc := &agentService{logger: slog.Default()}
		_, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, resource.SourceTypeHTTP, "algo")
		// Guard: calling Error() on a nil error would panic.
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "resource registry not initialized")
		}
	})
}
// TestGetKeyFromKBS covers getKeyFromKBS for a disabled KBS, a successful
// key fetch, and a KBS that answers with an HTTP error. Each subtest builds
// its own service, so no outcome depends on which other subtests ran first
// (e.g. under `go test -run` filtering).
func TestGetKeyFromKBS(t *testing.T) {
	ctx := context.Background()
	key := []byte("this is a 32-byte key!!!!!!!!!!!")

	// newService returns an agentService whose KBS config has the given
	// enabled flag and base URL.
	newService := func(enabled bool, kbsURL string) *agentService {
		return &agentService{
			logger: slog.Default(),
			computation: Computation{
				KBS: KBSConfig{
					Enabled: enabled,
					URL:     kbsURL,
				},
			},
		}
	}

	t.Run("KBS disabled", func(t *testing.T) {
		// Point at a KBS that would hand out a key. The error can then only
		// come from the Enabled check, not from an empty or unreachable URL.
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(key)
		}))
		defer ts.Close()

		svc := newService(false, ts.URL)
		_, err := svc.getKeyFromKBS(ctx, "path")
		assert.Error(t, err)
	})

	t.Run("Successful fetch", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			assert.Contains(t, r.URL.Path, "resource/path")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write(key)
		}))
		defer ts.Close()

		svc := newService(true, ts.URL)
		fetched, err := svc.getKeyFromKBS(ctx, "path")
		assert.NoError(t, err)
		assert.Equal(t, key, fetched)
	})

	t.Run("KBS error", func(t *testing.T) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer ts.Close()

		// Explicitly enabled, so the error must come from the 500 response.
		svc := newService(true, ts.URL)
		_, err := svc.getKeyFromKBS(ctx, "path")
		assert.Error(t, err)
	})
}
// TestInferSourceTypeDetailed checks how inferSourceType maps a URL or image
// reference to a source type. Inputs that match no type must yield "".
func TestInferSourceTypeDetailed(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		// Explicit URL schemes.
		{in: "s3://bucket/key", want: resource.SourceTypeS3},
		{in: "gs://bucket/key", want: resource.SourceTypeGCS},
		{in: "https://example.com/file", want: resource.SourceTypeHTTPS},
		{in: "http://example.com/file", want: resource.SourceTypeHTTP},
		{in: "docker://ubuntu", want: resource.SourceTypeOCIImage},
		{in: "oci:/path/to/dir", want: resource.SourceTypeOCIImage},
		// Bare image references (no scheme).
		{in: "ubuntu:latest", want: resource.SourceTypeOCIImage},
		{in: "myregistry.io/myimage:tag", want: resource.SourceTypeOCIImage},
		// Inputs that must not be classified.
		{in: "invalid-url-no-slash", want: ""},
		{in: "", want: ""},
		{in: "ftp://server/file", want: ""},
	}

	for i := range cases {
		c := cases[i]
		got := inferSourceType(c.in)
		assert.Equal(t, c.want, got, "URL: %s", c.in)
	}
}