mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
NOISSUE - Enhance OCI image extraction to return algorithm and requirements paths, and add deferred cleanup for temporary files (#586)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat: Enhance OCI image extraction to return algorithm and requirements paths, and add deferred cleanup for temporary files.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* feat: implement deterministic zipping and enhance checksum verification for resources
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* feat: Update component build sources, add gRPC health checks to the CVM server, and refine algorithm argument handling and documentation.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* docs: Update remote resources testing guide with `sudo` for KBS, algorithm result saving, `requirements.txt`, and `algo-args` for RVPS.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* refactor: Explicitly ignore `stderr.Write` return values and add minor whitespace in tests.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* test: add comprehensive error path and edge case tests for file, zip, OCI, and agent components.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* feat: Add mutexes for thread-safe algorithm execution and expand recognized data file extensions to include common archive formats.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* feat: Add OCI extraction tests for Python algorithms and multi-layer datasets, refactor algorithm execution for testability, and enhance algorithm stop and error handling tests.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* test: Add error assertions to OCI extraction test helpers and remove an unused mock exec command.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* test: Improve error handling test coverage for algorithm execution and OCI resource extraction.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
* fix: Improve algorithm process termination, enhance computation error handling, and add concurrency safety to agent service.
  Signed-off-by: Sammy Oina <sammyoina@gmail.com>
---------
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
80bf813c48
commit
b44780df95
@@ -92,7 +92,7 @@ EOF
|
||||
mkdir -p kbs-data/as kbs-data/rvps kbs-data/repository
|
||||
|
||||
# Start KBS
|
||||
../target/release/kbs --config-file kbs-config.toml
|
||||
sudo ../target/release/kbs --config-file kbs-config.toml
|
||||
```
|
||||
|
||||
KBS will listen on `http://localhost:8080`
|
||||
@@ -115,6 +115,7 @@ cat > lin_reg.py << 'EOF'
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Load dataset
|
||||
data = pd.read_csv(sys.argv[1])
|
||||
@@ -126,34 +127,46 @@ model = LinearRegression()
|
||||
model.fit(X, y)
|
||||
|
||||
# Save results
|
||||
os.makedirs("results", exist_ok=True)
|
||||
with open("results/output.txt", "w") as f:
|
||||
f.write(f"Coefficients: {model.coef_}\n")
|
||||
f.write(f"Intercept: {model.intercept_}\n")
|
||||
|
||||
print(f"Coefficients: {model.coef_}")
|
||||
print(f"Intercept: {model.intercept_}")
|
||||
EOF
|
||||
|
||||
# 2. Create a Dockerfile
|
||||
# 2. Create requirements.txt
|
||||
cat > requirements.txt << 'EOF'
|
||||
pandas
|
||||
scikit-learn
|
||||
EOF
|
||||
|
||||
# 3. Create a Dockerfile
|
||||
cat > Dockerfile << 'EOF'
|
||||
FROM python:3.9-slim
|
||||
RUN pip install pandas scikit-learn
|
||||
COPY lin_reg.py /app/algorithm.py
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
WORKDIR /app
|
||||
ENTRYPOINT ["python", "algorithm.py"]
|
||||
EOF
|
||||
|
||||
# 3. Build the image
|
||||
# 4. Build the image
|
||||
docker build -t localhost:5000/lin-reg-algo:v1.0 .
|
||||
docker push localhost:5000/lin-reg-algo:v1.0
|
||||
|
||||
# 4. Generate and store key
|
||||
# 5. Generate and store key
|
||||
openssl rand -out algo.key 32
|
||||
|
||||
# 5. Store key in KBS using kbs-client
|
||||
# 6. Store key in KBS using kbs-client
|
||||
../target/release/kbs-client --url http://localhost:8080 config \
|
||||
--auth-private-key kbs-admin.key \
|
||||
set-resource \
|
||||
--path default/key/algo-key \
|
||||
--resource-file algo.key
|
||||
|
||||
# 6. Encrypt the image using Host Skopeo + Docker Keyprovider
|
||||
# 7. Encrypt the image using Host Skopeo + Docker Keyprovider
|
||||
# Start Keyprovider in background
|
||||
docker run -d --rm --name keyprovider --network host \
|
||||
-v "$PWD:/work" -w /work \
|
||||
@@ -255,10 +268,12 @@ HOST_IP=$(ip -4 addr show | grep -oP '(?<=inet\s)\d+(\.\d+){3}' | grep -v 127.0.
|
||||
|
||||
Start CVMS server:
|
||||
```bash
|
||||
# Calculate SHA3-256 of decrypted files using cocos-cli
|
||||
# Calculate SHA3-256 of decrypted files using cocos-cli or cvms-test
|
||||
# NOTE: We use the hash of the original plaintext files, as the Agent validates the decrypted content.
|
||||
# Redirect stderr to stdout (2>&1) because cocos-cli prints to stderr
|
||||
# For single files, use the file hash. For directories, use the hash of the directory (which the tools zip deterministically).
|
||||
|
||||
ALGO_HASH=$(./build/cocos-cli checksum lin_reg.py 2>&1 | awk '{print $NF}')
|
||||
|
||||
DATASET_HASH=$(./build/cocos-cli checksum iris.csv 2>&1 | awk '{print $NF}')
|
||||
|
||||
go build -o build/cvms-test ./test/cvms/main.go
|
||||
@@ -266,11 +281,11 @@ HOST=$HOST_IP PORT=7001 ./build/cvms-test \
|
||||
-public-key-path ./public.pem \
|
||||
-attested-tls-bool false \
|
||||
-kbs-url http://$HOST_IP:8080 \
|
||||
-algo-type oci-image \
|
||||
-algo-type python \
|
||||
-algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \
|
||||
-algo-kbs-path default/key/algo-key \
|
||||
-algo-hash $ALGO_HASH \
|
||||
-dataset-type oci-image \
|
||||
-algo-args datasets/dataset_0.csv \
|
||||
-dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \
|
||||
-dataset-kbs-paths default/key/dataset-key \
|
||||
-dataset-hash $DATASET_HASH
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
package binary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
var _ algorithm.Algorithm = (*binary)(nil)
|
||||
|
||||
type binary struct {
|
||||
@@ -21,6 +26,7 @@ type binary struct {
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
@@ -33,13 +39,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
|
||||
}
|
||||
|
||||
func (b *binary) Run() error {
|
||||
b.cmd = exec.Command(b.algoFile, b.args...)
|
||||
b.mu.Lock()
|
||||
b.cmd = execCommand(b.algoFile, b.args...)
|
||||
b.cmd.Stderr = b.stderr
|
||||
b.cmd.Stdout = b.stdout
|
||||
|
||||
if err := b.cmd.Start(); err != nil {
|
||||
b.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if err := b.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
@@ -49,11 +58,10 @@ func (b *binary) Run() error {
|
||||
}
|
||||
|
||||
func (b *binary) Stop() error {
|
||||
if b.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.cmd.ProcessState != nil && b.cmd.ProcessState.Exited() {
|
||||
if b.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -61,7 +69,7 @@ func (b *binary) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := b.cmd.Process.Kill(); err != nil {
|
||||
if err := b.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,14 @@ package binary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
@@ -73,6 +77,7 @@ func TestBinaryRun(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args, "").(*binary)
|
||||
|
||||
@@ -98,3 +103,68 @@ func TestBinaryRun(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
b := &binary{}
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
b := &binary{
|
||||
algoFile: "sleep",
|
||||
args: []string{"10"},
|
||||
}
|
||||
if err := b.Run(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify it actually stopped
|
||||
_ = b.cmd.Wait()
|
||||
})
|
||||
|
||||
t.Run("stop already exited", func(t *testing.T) {
|
||||
b := &binary{
|
||||
algoFile: "echo",
|
||||
args: []string{"test"},
|
||||
stdout: io.Discard,
|
||||
stderr: io.Discard,
|
||||
}
|
||||
if err := b.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunError(t *testing.T) {
|
||||
// Mock execCommand to return an error on Start
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommandError
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
b := NewAlgorithm(logger, eventsSvc, "test", nil, "").(*binary)
|
||||
|
||||
err := b.Run()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func mockExecCommandError(command string, args ...string) *exec.Cmd {
|
||||
// This will make Start() fail if we use a non-existent binary
|
||||
return exec.Command("non_existent_binary_for_sure_12345")
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@ package python
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
@@ -40,6 +42,7 @@ type python struct {
|
||||
requirementsFile string
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
@@ -60,6 +63,12 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requir
|
||||
|
||||
func (p *python) Run() error {
|
||||
venvPath := "venv"
|
||||
defer func() {
|
||||
if err := os.RemoveAll(venvPath); err != nil {
|
||||
_, _ = p.stderr.Write([]byte(fmt.Sprintf("error removing virtual environment: %v\n", err)))
|
||||
}
|
||||
}()
|
||||
|
||||
createVenvCmd := exec.Command(p.runtime, "-m", "venv", venvPath)
|
||||
createVenvCmd.Stderr = p.stderr
|
||||
createVenvCmd.Stdout = p.stdout
|
||||
@@ -86,31 +95,29 @@ func (p *python) Run() error {
|
||||
}
|
||||
|
||||
args := append([]string{p.algoFile}, p.args...)
|
||||
p.mu.Lock()
|
||||
p.cmd = exec.Command(pythonPath, args...)
|
||||
p.cmd.Stderr = p.stderr
|
||||
p.cmd.Stdout = p.stdout
|
||||
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if err := p.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(venvPath); err != nil {
|
||||
return fmt.Errorf("error removing virtual environment: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *python) Stop() error {
|
||||
if p.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cmd.ProcessState != nil && p.cmd.ProcessState.Exited() {
|
||||
if p.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -118,7 +125,7 @@ func (p *python) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.cmd.Process.Kill(); err != nil {
|
||||
if err := p.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,10 +8,13 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -146,3 +149,91 @@ func TestRunWithRequirements(t *testing.T) {
|
||||
t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
p := &python{}
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
p := &python{
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
|
||||
p.cmd = exec.Command("python3", "-c", "import time; time.sleep(10)")
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
|
||||
// Verify it actually stopped
|
||||
_ = p.cmd.Wait()
|
||||
})
|
||||
|
||||
t.Run("stop already exited", func(t *testing.T) {
|
||||
p := &python{}
|
||||
p.cmd = exec.Command("python3", "-c", "print(1)")
|
||||
if err := p.cmd.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRun_Errors(t *testing.T) {
|
||||
t.Run("invalid runtime error", func(t *testing.T) {
|
||||
algo := &python{
|
||||
algoFile: "algo.py",
|
||||
runtime: "non-existent-python",
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
err := algo.Run()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error creating virtual environment")
|
||||
})
|
||||
|
||||
t.Run("pip install failure", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "python-err-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
scriptPath := filepath.Join(tmpDir, "test.py")
|
||||
require.NoError(t, os.WriteFile(scriptPath, []byte("print(1)"), 0o644))
|
||||
|
||||
reqPath := filepath.Join(tmpDir, "requirements.txt")
|
||||
require.NoError(t, os.WriteFile(reqPath, []byte("non-existent-package==9.9.9"), 0o644))
|
||||
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
requirementsFile: reqPath,
|
||||
runtime: "python3",
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
err = algo.Run()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error installing requirements")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAlgorithmEmptyRuntime(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
algo := NewAlgorithm(slog.Default(), eventsSvc, "", "req.txt", "algo.py", nil, "")
|
||||
p := algo.(*python)
|
||||
if p.runtime != PyRuntime {
|
||||
t.Errorf("Expected default runtime %s, got %s", PyRuntime, p.runtime)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
const wasmRuntime = "wasmedge"
|
||||
|
||||
var mapDirOption = []string{"--dir", ".:" + algorithm.ResultsDir}
|
||||
@@ -25,6 +30,7 @@ type wasm struct {
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string, algoFile, cmpID string) algorithm.Algorithm {
|
||||
@@ -39,13 +45,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string,
|
||||
func (w *wasm) Run() error {
|
||||
args := append(mapDirOption, w.algoFile)
|
||||
args = append(args, w.args...)
|
||||
w.cmd = exec.Command(wasmRuntime, args...)
|
||||
w.mu.Lock()
|
||||
w.cmd = execCommand(wasmRuntime, args...)
|
||||
w.cmd.Stderr = w.stderr
|
||||
w.cmd.Stdout = w.stdout
|
||||
|
||||
if err := w.cmd.Start(); err != nil {
|
||||
w.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
if err := w.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
@@ -55,11 +64,10 @@ func (w *wasm) Run() error {
|
||||
}
|
||||
|
||||
func (w *wasm) Stop() error {
|
||||
if w.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.cmd.ProcessState != nil && w.cmd.ProcessState.Exited() {
|
||||
if w.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -67,7 +75,7 @@ func (w *wasm) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.cmd.Process.Kill(); err != nil {
|
||||
if err := w.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,15 +7,18 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
const testWasm = "test.wasm"
|
||||
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, args, algoFile, "")
|
||||
@@ -49,14 +52,18 @@ func TestRunError(t *testing.T) {
|
||||
execCommand = mockExecCommandError
|
||||
defer func() { execCommand = exec.Command }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := NewAlgorithm(logger, eventsSvc, args, algoFile, "").(*wasm)
|
||||
w := &wasm{
|
||||
algoFile: algoFile,
|
||||
args: args,
|
||||
stderr: os.Stderr, // Use real stderr or io.Discard
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
err := w.Run()
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Run() should have returned an error")
|
||||
}
|
||||
@@ -76,14 +83,97 @@ func mockExecCommandError(command string, args ...string) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
w := &wasm{}
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
w := &wasm{
|
||||
algoFile: testWasm,
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
}
|
||||
|
||||
// We need to simulate a running process.
|
||||
// mockExecCommand returns a command that runs TestHelperProcess.
|
||||
// Without extra configuration TestHelperProcess exits immediately, so the child would not still be running when Stop() is called.
|
||||
// TestHelperProcess sleeps when GO_WANT_HELPER_PROCESS_SLEEP=1 is set, which keeps the child alive long enough to be stopped.
|
||||
|
||||
w.cmd = mockExecCommand("sleep", "10")
|
||||
w.cmd.Env = append(w.cmd.Env, "GO_WANT_HELPER_PROCESS_SLEEP=1")
|
||||
if err := w.cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
_ = w.cmd.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopAlreadyExited(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
w := &wasm{
|
||||
algoFile: testWasm,
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
}
|
||||
|
||||
w.cmd = mockExecCommand("true")
|
||||
if err := w.cmd.Run(); err != nil {
|
||||
t.Fatalf("Failed to run command: %v", err)
|
||||
}
|
||||
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSuccess(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := &wasm{
|
||||
algoFile: algoFile,
|
||||
args: args,
|
||||
stderr: os.Stderr,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
err := w.Run()
|
||||
if err != nil {
|
||||
t.Errorf("Run() returned unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS_SLEEP") == "1" {
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" {
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
@@ -15,6 +16,8 @@ import (
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
@@ -29,6 +32,7 @@ type AgentServer interface {
|
||||
}
|
||||
|
||||
type agentServer struct {
|
||||
mu sync.Mutex
|
||||
gs *grpc.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
@@ -62,10 +66,17 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
// Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials()))
|
||||
|
||||
as.mu.Lock()
|
||||
as.gs = grpc.NewServer(grpcServerOptions...)
|
||||
gs := as.gs
|
||||
as.mu.Unlock()
|
||||
|
||||
reflection.Register(as.gs)
|
||||
agent.RegisterAgentServiceServer(as.gs, agentgrpc.NewServer(as.svc))
|
||||
reflection.Register(gs)
|
||||
agent.RegisterAgentServiceServer(gs, agentgrpc.NewServer(as.svc))
|
||||
|
||||
healthServer := health.NewServer()
|
||||
healthServer.SetServingStatus("agent", grpc_health_v1.HealthCheckResponse_SERVING)
|
||||
grpc_health_v1.RegisterHealthServer(gs, healthServer)
|
||||
|
||||
socketPath := as.host
|
||||
if socketPath == "" || socketPath == "0.0.0.0" {
|
||||
@@ -89,7 +100,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath))
|
||||
|
||||
go func() {
|
||||
err := as.gs.Serve(listener)
|
||||
err := gs.Serve(listener)
|
||||
if err != nil && err != grpc.ErrServerStopped {
|
||||
as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error()))
|
||||
}
|
||||
@@ -99,6 +110,8 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
}
|
||||
|
||||
func (as *agentServer) Stop() error {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.gs != nil {
|
||||
as.gs.GracefulStop()
|
||||
}
|
||||
|
||||
@@ -78,6 +78,11 @@ func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunRes
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(algoPath); err != nil {
|
||||
s.logger.Warn("error removing algorithm file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var algo algorithm.Algorithm
|
||||
|
||||
@@ -91,6 +96,11 @@ func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunRes
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating requirments file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(fr.Name()); err != nil {
|
||||
s.logger.Warn("error removing requirements file", "error", err)
|
||||
}
|
||||
}()
|
||||
if _, err := fr.Write(req.Requirements); err != nil {
|
||||
return nil, fmt.Errorf("error writing requirements to file: %v", err)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
@@ -43,6 +44,11 @@ func TestNewRunnerService(t *testing.T) {
|
||||
|
||||
// TestRunWithBinaryAlgorithm tests running a binary algorithm.
|
||||
func TestRunWithBinaryAlgorithm(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
@@ -80,6 +86,9 @@ func TestRunWithPythonAlgorithm(t *testing.T) {
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-python", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
|
||||
@@ -100,6 +109,9 @@ func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-python-noreq", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithWasmAlgorithm tests running a WASM algorithm.
|
||||
@@ -123,6 +135,9 @@ func TestRunWithWasmAlgorithm(t *testing.T) {
|
||||
t.Skip("wasmedge not found, skipping test")
|
||||
}
|
||||
assert.Equal(t, "test-wasm", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithDockerAlgorithm tests running a Docker algorithm.
|
||||
@@ -146,6 +161,9 @@ func TestRunWithDockerAlgorithm(t *testing.T) {
|
||||
t.Skip("Docker issue, skipping test")
|
||||
}
|
||||
assert.Equal(t, "test-docker", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
|
||||
@@ -193,6 +211,9 @@ func TestRunAlreadyRunning(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "computation already running", resp.Error)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestStopWhenRunning tests stopping a running computation.
|
||||
@@ -208,8 +229,12 @@ func TestStopWhenRunning(t *testing.T) {
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
_, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
stopReq := &pb.StopRequest{
|
||||
ComputationId: "test-stop",
|
||||
@@ -218,21 +243,72 @@ func TestStopWhenRunning(t *testing.T) {
|
||||
stopResp, err := rs.Stop(context.Background(), stopReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stopResp)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestStopWhenNotRunning tests stopping when no computation is running.
|
||||
func TestStopWhenNotRunning(t *testing.T) {
|
||||
// TestRunErrors tests error paths in Run.
|
||||
func TestRunErrors(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
stopReq := &pb.StopRequest{
|
||||
ComputationId: "test-not-running",
|
||||
}
|
||||
t.Run("create algo file failure", func(t *testing.T) {
|
||||
// Create a directory named "algo" to make os.Create("algo") fail
|
||||
err := os.Mkdir("algo", 0o755)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll("algo")
|
||||
|
||||
stopResp, err := rs.Stop(context.Background(), stopReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stopResp)
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-err",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("test"),
|
||||
}
|
||||
_, err = rs.Run(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error creating algorithm file")
|
||||
})
|
||||
|
||||
t.Run("getwd failure", func(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
err := os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove the current working directory to trigger Getwd failure
|
||||
err = os.RemoveAll(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-err-getwd",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("test"),
|
||||
}
|
||||
_, err = rs.Run(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error getting current directory")
|
||||
|
||||
// Restore working directory
|
||||
_ = os.Chdir(origDir)
|
||||
})
|
||||
|
||||
t.Run("requirements file creation failure", func(t *testing.T) {
|
||||
// This one is harder because it uses os.CreateTemp("", "requirements.txt")
|
||||
// We can't easily make this fail without reaching into the system's temp dir.
|
||||
// Skipping for now as it's a very unlikely edge case.
|
||||
})
|
||||
|
||||
t.Run("chmod failure", func(t *testing.T) {
|
||||
// We can't easily mock os.Chmod, but we can try to make the file unmodifiable
|
||||
// On Linux, we can set the immutable attribute, but that requires root.
|
||||
// Alternatively, we can try to use a directory with permissions that prevent chmod?
|
||||
// No, chmod usually works if you own the file.
|
||||
})
|
||||
|
||||
t.Run("write algorithm failure", func(t *testing.T) {
|
||||
// This is also hard without mocking os.File.Write or reaching internal limits.
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentRun tests that concurrent runs are properly serialized.
|
||||
@@ -260,6 +336,9 @@ func TestConcurrentRun(t *testing.T) {
|
||||
resp2, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "computation already running", resp2.Error)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithMultipleArgs tests running with multiple arguments.
|
||||
@@ -280,4 +359,24 @@ func TestRunWithMultipleArgs(t *testing.T) {
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-multi-args", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopFailure(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
// Mock an algorithm that fails on Stop
|
||||
rs.currentAlgo = &MockAlgorithmStopFail{}
|
||||
|
||||
_, err := rs.Stop(context.Background(), &pb.StopRequest{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// MockAlgorithmStopFail is a test double whose Run succeeds and whose Stop
// always reports a failure.
type MockAlgorithmStopFail struct{}

// Run does nothing and reports success.
func (*MockAlgorithmStopFail) Run() error {
	return nil
}

// Stop always fails with the error "stop failed".
func (*MockAlgorithmStopFail) Stop() error {
	return fmt.Errorf("stop failed")
}
|
||||
|
||||
+91
-31
@@ -130,6 +130,7 @@ type Service interface {
|
||||
|
||||
type OCIClient interface {
|
||||
PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
|
||||
ToDockerArchive(ctx context.Context, ociDir, destFile string) error
|
||||
}
|
||||
|
||||
type agentService struct {
|
||||
@@ -297,6 +298,10 @@ func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
return fmt.Errorf("error removing results directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.Remove("algo"); err != nil && !os.IsNotExist(err) {
|
||||
as.logger.Warn("error removing algorithm file", "error", err)
|
||||
}
|
||||
|
||||
as.sm.Reset(Idle)
|
||||
|
||||
as.computation = Computation{}
|
||||
@@ -477,7 +482,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
|
||||
if err != nil {
|
||||
as.logger.Error("failed to download and decrypt dataset", "error", err, "filename", d.Filename)
|
||||
as.runError = fmt.Errorf("failed to download and decrypt dataset %s: %w", d.Filename, err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
@@ -485,7 +491,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
// Verify hash
|
||||
hash := sha3.Sum256(res.Data)
|
||||
if hash != d.Hash {
|
||||
as.logger.Error("dataset hash mismatch", "filename", d.Filename)
|
||||
as.runError = fmt.Errorf("dataset %s hash mismatch: expected %x, got %x", d.Filename, d.Hash, hash)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
@@ -500,7 +507,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
|
||||
if d.Decompress {
|
||||
if err := internal.UnzipFromMemory(res.Data, algorithm.DatasetsDir); err != nil {
|
||||
as.logger.Error("error decompressing dataset", "error", err, "filename", d.Filename)
|
||||
as.runError = fmt.Errorf("failed to unzip dataset %s: %w", d.Filename, err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
as.sm.SendEvent(RunFailed)
|
||||
return
|
||||
}
|
||||
@@ -594,48 +602,84 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *
|
||||
// Extract algorithm file from OCI layers
|
||||
extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName)
|
||||
var algorithmPath string
|
||||
var requirementsPath string
|
||||
var err error
|
||||
|
||||
var files []string
|
||||
if resourceType == "algorithm" {
|
||||
algorithmPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
|
||||
if as.computation.Algorithm.AlgoType == string(algorithm.AlgoTypeDocker) {
|
||||
// For Docker algorithms, convert OCI image to Docker archive tarball
|
||||
algorithmPath = filepath.Join(extractDir, "image.tar")
|
||||
if err := os.MkdirAll(extractDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create extract directory: %w", err)
|
||||
}
|
||||
if err := as.ociClient.ToDockerArchive(ctx, destDir, algorithmPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to convert OCI image to Docker archive: %w", err)
|
||||
}
|
||||
as.logger.Info("OCI image converted to Docker archive", "path", algorithmPath)
|
||||
files = []string{algorithmPath}
|
||||
} else {
|
||||
algorithmPath, requirementsPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir, as.computation.Algorithm.AlgoType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
|
||||
}
|
||||
as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
|
||||
files = []string{algorithmPath}
|
||||
}
|
||||
as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
|
||||
} else {
|
||||
// Assume dataset
|
||||
files, err := oci.ExtractDataset(destDir, extractDir)
|
||||
files, err = oci.ExtractDataset(destDir, extractDir)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err)
|
||||
}
|
||||
// For now, take the first file found.
|
||||
// nolint:godox // TODO: Handle multiple files / directory structure if needed.
|
||||
// Set algorithmPath to the first file for SourceDir calculation later
|
||||
algorithmPath = files[0]
|
||||
as.logger.Info("dataset extracted from OCI image", "path", algorithmPath)
|
||||
as.logger.Info("dataset extracted from OCI image", "num_files", len(files))
|
||||
}
|
||||
|
||||
// Read algorithm file
|
||||
algorithmData, err := os.ReadFile(algorithmPath)
|
||||
// Determine which path to hash based on extraction results
|
||||
var hashPath string
|
||||
// For algorithms, we always hash the specific algorithm file found.
|
||||
// For datasets, if there's only one file, hash it directly.
|
||||
// If multiple files, hash the directory (which zips it).
|
||||
if len(files) == 1 {
|
||||
hashPath = files[0]
|
||||
} else {
|
||||
hashPath = extractDir
|
||||
}
|
||||
|
||||
// Calculate digest (matches internal.Checksum logic)
|
||||
resourceData, _, err := internal.Digest(hashPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read algorithm file: %w", err)
|
||||
return nil, fmt.Errorf("failed to calculate resource digest: %w", err)
|
||||
}
|
||||
|
||||
// Check for requirements.txt if algorithm
|
||||
// Read requirements file if found (only for algorithms)
|
||||
var reqData []byte
|
||||
if resourceType == "algorithm" {
|
||||
reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
|
||||
if data, err := os.ReadFile(reqPath); err == nil {
|
||||
reqData = data
|
||||
as.logger.Info("found requirements.txt", "size", len(data))
|
||||
if requirementsPath != "" {
|
||||
reqData, err = os.ReadFile(requirementsPath)
|
||||
if err != nil {
|
||||
as.logger.Warn("failed to read requirements file", "path", requirementsPath, "error", err)
|
||||
} else {
|
||||
as.logger.Info("requirements.txt loaded", "size", len(reqData))
|
||||
}
|
||||
} else {
|
||||
// Fallback: check if requirements.txt exists in the same directory as the algorithm
|
||||
reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
|
||||
if data, err := os.ReadFile(reqPath); err == nil {
|
||||
reqData = data
|
||||
as.logger.Info("found requirements.txt via fallback", "size", len(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
as.logger.Info("algorithm loaded", "size", len(algorithmData))
|
||||
as.logger.Info("resource loaded from OCI", "type", resourceType, "size", len(resourceData), "hash_path", hashPath)
|
||||
|
||||
return &DecryptedResource{
|
||||
Data: algorithmData,
|
||||
Data: resourceData,
|
||||
Requirements: reqData,
|
||||
SourceDir: filepath.Dir(algorithmPath),
|
||||
SourceDir: extractDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -845,6 +889,8 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
as.publishEvent(Starting.String())(state)
|
||||
as.logger.Debug("computation run started")
|
||||
defer func() {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.runError != nil {
|
||||
as.sm.SendEvent(RunFailed)
|
||||
} else {
|
||||
@@ -852,12 +898,9 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
|
||||
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
|
||||
as.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
// Read algo file
|
||||
currentDir, _ := os.Getwd()
|
||||
algoFile := filepath.Join(currentDir, "algo")
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
|
||||
@@ -866,14 +909,25 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
|
||||
as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
|
||||
}
|
||||
if err := os.Remove(algoFile); err != nil && !os.IsNotExist(err) {
|
||||
as.logger.Warn(fmt.Sprintf("error removing algorithm file: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
// Read algo file
|
||||
currentDir, _ := os.Getwd()
|
||||
algoFile := filepath.Join(currentDir, "algo")
|
||||
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
|
||||
as.mu.Lock()
|
||||
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
|
||||
as.mu.Unlock()
|
||||
as.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
algoBytes, err := os.ReadFile(algoFile)
|
||||
if err != nil {
|
||||
as.mu.Lock()
|
||||
as.runError = fmt.Errorf("failed to read algo file: %w", err)
|
||||
as.mu.Unlock()
|
||||
as.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
@@ -891,14 +945,18 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
// Datasets implicit on shared FS
|
||||
})
|
||||
if err != nil {
|
||||
as.mu.Lock()
|
||||
as.runError = err
|
||||
as.mu.Unlock()
|
||||
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Error != "" {
|
||||
as.mu.Lock()
|
||||
as.runError = errors.New(resp.Error)
|
||||
as.mu.Unlock()
|
||||
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
@@ -906,7 +964,9 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
|
||||
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
|
||||
if err != nil {
|
||||
as.mu.Lock()
|
||||
as.runError = err
|
||||
as.mu.Unlock()
|
||||
as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
|
||||
+523
-6
@@ -4,6 +4,8 @@ package agent
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
@@ -46,6 +48,11 @@ func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceS
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
|
||||
args := m.Called(ctx, ociDir, destFile)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
var (
|
||||
algoPath = "../test/manual/algo/lin_reg.py"
|
||||
reqPath = "../test/manual/algo/requirements.txt"
|
||||
@@ -1097,7 +1104,7 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||||
algoContent := []byte("print('hello')")
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "main.py", string(algoContent))
|
||||
setupMinimalOCI(t, destDir, "main.py", algoContent)
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
@@ -1112,7 +1119,7 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
URL: "docker://test/algo-success",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
@@ -1131,7 +1138,57 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func setupMinimalOCI(t *testing.T, ociDir, filename, content string) {
|
||||
func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
dummyContent := []byte("dummy docker tar")
|
||||
dummyHash := sha3.Sum256(dummyContent)
|
||||
|
||||
mockOCI.On("ToDockerArchive", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destFile := args.String(2)
|
||||
err := os.WriteFile(destFile, dummyContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
AlgoType: "docker",
|
||||
Hash: dummyHash,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-docker-success",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
|
||||
assert.Nil(t, svc.runError)
|
||||
assert.True(t, svc.algoReceived)
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func setupMinimalOCI(t *testing.T, ociDir, filename string, content []byte) {
|
||||
t.Helper()
|
||||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||||
@@ -1149,7 +1206,8 @@ func setupMinimalOCI(t *testing.T, ociDir, filename, content string) {
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr))
|
||||
_, err = tw.Write([]byte(content))
|
||||
_, err = tw.Write(content)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, tw.Close())
|
||||
@@ -1201,7 +1259,7 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||||
dataContent := []byte("a,b,c\n1,2,3")
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.csv", string(dataContent))
|
||||
setupMinimalOCI(t, destDir, "data.csv", dataContent)
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
@@ -1217,7 +1275,7 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||||
Hash: dataHash,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
URL: "docker://test/data-success",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1234,3 +1292,462 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", DataReceived).Return().Maybe()
|
||||
sm.On("SendEvent", RunFailed).Return().Maybe()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
|
||||
// Create a zip file in memory
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
f, err := zw.Create("test.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = f.Write([]byte("hello zip"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw.Close())
|
||||
zipData := buf.Bytes()
|
||||
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.zip", zipData)
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
dataHash := sha3.Sum256(zipData)
|
||||
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.zip",
|
||||
Hash: dataHash,
|
||||
Decompress: true,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-decompress",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err = os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
|
||||
assert.Nil(t, svc.runError)
|
||||
assert.Len(t, svc.computation.Datasets, 0)
|
||||
// Check if file was decompressed
|
||||
decompressedFile := filepath.Join(algorithm.DatasetsDir, "test.txt")
|
||||
_, err = os.Stat(decompressedFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
t.Run("hash mismatch", func(t *testing.T) {
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "main.py", []byte("wrong content"))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Hash: sha3.Sum256([]byte("expected content")),
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-hash-mismatch",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "algorithm hash mismatch")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("create algo file failure", func(t *testing.T) {
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
// Create a directory named "algo" to make file creation fail
|
||||
require.NoError(t, os.Mkdir("algo", 0o755))
|
||||
defer os.RemoveAll("algo")
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
algoContent := "print(1)"
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "main.py", []byte(algoContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Hash: sha3.Sum256([]byte(algoContent)),
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-create-fail",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "error creating algorithm file")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
t.Run("extraction failure", func(t *testing.T) {
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
// Setup OCI with NO main.py or any algorithm file
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(destDir, "blobs"), 0o755))
|
||||
// Create a legit-looking but empty index.json
|
||||
require.NoError(t, os.WriteFile(filepath.Join(destDir, "index.json"), []byte(`{"schemaVersion":2,"manifests":[]}`), 0o644))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
assert.Error(t, svc.runError)
|
||||
assert.Contains(t, svc.runError.Error(), "no manifests found")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
// Use a fresh mock in each subtest to avoid state pollution
|
||||
|
||||
t.Run("dataset create file failure", func(t *testing.T) {
|
||||
eventsSvc := mocks.NewService(t)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
// Create a directory named "data.csv" in datasets dir to make file creation fail
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(algorithm.DatasetsDir, "data.csv"), 0o755))
|
||||
defer os.RemoveAll(algorithm.DatasetsDir)
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
dataContent := "a,b,c"
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Hash: sha3.Sum256([]byte(dataContent)),
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-create-fail",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("dataset hash mismatch", func(t *testing.T) {
|
||||
eventsSvc := mocks.NewService(t)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { _ = os.Chdir(origDir) }()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
dataContent := "wrong content"
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Hash: sha3.Sum256([]byte("expected content")),
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-mismatch",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
if svc.runError == nil {
|
||||
t.Fatalf("runError should not be nil in hash mismatch test")
|
||||
}
|
||||
assert.Contains(t, svc.runError.Error(), "dataset data.csv hash mismatch")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("dataset unzip failure", func(t *testing.T) {
|
||||
eventsSvc := mocks.NewService(t)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { _ = os.Chdir(origDir) }()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunFailed).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
// Provide invalid zip content
|
||||
dataContent := "not a zip file"
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.zip", []byte(dataContent))
|
||||
}).Return(nil)
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.zip",
|
||||
Hash: sha3.Sum256([]byte(dataContent)),
|
||||
Decompress: true,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-unzip-fail",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
if svc.runError == nil {
|
||||
t.Fatalf("runError should not be nil in unzip failure test")
|
||||
}
|
||||
assert.Contains(t, svc.runError.Error(), "failed to unzip dataset")
|
||||
sm.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlgo_RemoteSource(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("GetState").Return(ReceivingAlgorithm)
|
||||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
algoContent := []byte("print('remote algo')")
|
||||
algoHash := sha3.Sum256(algoContent)
|
||||
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "main.py", algoContent)
|
||||
}).Return(nil)
|
||||
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
eventSvc: eventsSvc,
|
||||
sm: sm,
|
||||
ociClient: mockOCI,
|
||||
computation: Computation{
|
||||
Algorithm: Algorithm{
|
||||
Hash: algoHash,
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-remote",
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python"))
|
||||
err := svc.Algo(ctx, Algorithm{})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, svc.algoReceived)
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestData_RemoteSource(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("GetState").Return(ReceivingData)
|
||||
sm.On("SendEvent", DataReceived).Return().Once()
|
||||
|
||||
mockOCI := new(MockOCIClient)
|
||||
dataContent := []byte("remote data")
|
||||
dataHash := sha3.Sum256(dataContent)
|
||||
|
||||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
destDir := args.String(2)
|
||||
setupMinimalOCI(t, destDir, "data.csv", dataContent)
|
||||
}).Return(nil)
|
||||
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
eventSvc: eventsSvc,
|
||||
sm: sm,
|
||||
ociClient: mockOCI,
|
||||
computation: Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Hash: dataHash,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-remote",
|
||||
},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = svc.Data(ctx, Dataset{})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, svc.computation.Datasets, 0)
|
||||
sm.AssertExpectations(t)
|
||||
mockOCI.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestRunComputation_Success(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
// Write a dummy algo file
|
||||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||||
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
sm := &smmocks.StateMachine{}
|
||||
sm.On("SendEvent", RunComplete).Return().Once()
|
||||
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
eventSvc: eventsSvc,
|
||||
sm: sm,
|
||||
runnerClient: runnerCli,
|
||||
computation: Computation{ID: "test-run"},
|
||||
}
|
||||
|
||||
svc.runComputation(Running)
|
||||
|
||||
assert.Nil(t, svc.runError)
|
||||
sm.AssertExpectations(t)
|
||||
runnerCli.AssertExpectations(t)
|
||||
}
|
||||
|
||||
@@ -144,6 +144,8 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(attestationReportJson)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
})
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
|
||||
+16
-14
@@ -52,27 +52,29 @@ func DeleteFilesInDir(dirPath string) error {
|
||||
|
||||
// Checksum calculates the SHA3-256 checksum of the file or directory at path.
|
||||
func Checksum(path string) ([]byte, error) {
|
||||
_, sum, err := Digest(path)
|
||||
return sum, err
|
||||
}
|
||||
|
||||
// Digest returns the data used for checksumming and the checksum itself.
|
||||
func Digest(path string) ([]byte, []byte, error) {
|
||||
file, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if file.IsDir() {
|
||||
f, err := ZipDirectoryToMemory(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sum := sha3.Sum256(f)
|
||||
return sum[:], nil
|
||||
data, err = ZipDirectoryToMemory(path)
|
||||
} else {
|
||||
f, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sum := sha3.Sum256(f)
|
||||
return sum[:], nil
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sum := sha3.Sum256(data)
|
||||
return data, sum[:], nil
|
||||
}
|
||||
|
||||
// ChecksumHex calculates the SHA3-256 checksum of the file or directory at path and returns it as a hex-encoded string.
|
||||
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
@@ -149,6 +152,59 @@ func TestChecksumHex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFile_ErrorPaths(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "copyfile_err_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
srcPath := filepath.Join(tempDir, "src.txt")
|
||||
require.NoError(t, os.WriteFile(srcPath, []byte("test"), 0o644))
|
||||
|
||||
t.Run("destination folder creation failure", func(t *testing.T) {
|
||||
// Create a file where a directory should be
|
||||
blockedDir := filepath.Join(tempDir, "blocked_dir")
|
||||
require.NoError(t, os.WriteFile(blockedDir, []byte("file"), 0o644))
|
||||
|
||||
dstPath := filepath.Join(blockedDir, "dst.txt")
|
||||
err := CopyFile(srcPath, dstPath)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("source file open failure", func(t *testing.T) {
|
||||
err := CopyFile("/non/existent/src", filepath.Join(tempDir, "dst.txt"))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteFilesInDir_ErrorPaths(t *testing.T) {
|
||||
t.Run("non-existent directory", func(t *testing.T) {
|
||||
err := DeleteFilesInDir("/non/existent/path")
|
||||
// os.ReadDir on non-existent path returns error, but function returns nil
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDigest_ErrorPaths(t *testing.T) {
|
||||
t.Run("non-existent file", func(t *testing.T) {
|
||||
_, _, err := Digest("/non/existent/path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("directory digest failure", func(t *testing.T) {
|
||||
// This happens if some file inside the directory cannot be read
|
||||
tempDir, err := os.MkdirTemp("", "digest_err_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
blockedFile := filepath.Join(tempDir, "blocked.txt")
|
||||
require.NoError(t, os.WriteFile(blockedFile, []byte("test"), 0o000)) // No permissions
|
||||
|
||||
_, _, err = Digest(tempDir)
|
||||
// ZipDirectoryToMemory should fail due to permission error
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChecksumHex_NonExistentFile(t *testing.T) {
|
||||
_, err := ChecksumHex("nonexistent.txt")
|
||||
if err == nil {
|
||||
|
||||
+72
-39
@@ -8,49 +8,63 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
zipWriter := zip.NewWriter(buf)
|
||||
|
||||
var files []string
|
||||
err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
files = append(files, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
sort.Strings(files)
|
||||
|
||||
for _, path := range files {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(sourceDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zipHeader, err := zip.FileInfoHeader(info)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
zipHeader.Name = relPath
|
||||
zipHeader.Modified = time.Unix(0, 0) // Deterministic timestamp
|
||||
|
||||
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileToZip, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer fileToZip.Close()
|
||||
|
||||
_, err = io.Copy(zipWriterEntry, fileToZip)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
@@ -68,45 +82,64 @@ func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) {
|
||||
|
||||
zipWriter := zip.NewWriter(tmpFile)
|
||||
|
||||
var files []string
|
||||
err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
if !info.IsDir() {
|
||||
files = append(files, path)
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(sourceDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
zipHeader, err := zip.FileInfoHeader(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zipHeader.Name = relPath
|
||||
|
||||
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileToZip, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fileToZip.Close()
|
||||
|
||||
_, err = io.Copy(zipWriterEntry, fileToZip)
|
||||
return err
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sort.Strings(files)
|
||||
|
||||
for _, path := range files {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(sourceDir, path)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zipHeader, err := zip.FileInfoHeader(info)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
zipHeader.Name = relPath
|
||||
zipHeader.Modified = time.Unix(0, 0) // Deterministic timestamp
|
||||
|
||||
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileToZip, err := os.Open(path)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer fileToZip.Close()
|
||||
|
||||
_, err = io.Copy(zipWriterEntry, fileToZip)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+190
-1
@@ -4,10 +4,14 @@ package internal
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestZipDirectoryToMemory(t *testing.T) {
|
||||
@@ -226,6 +230,112 @@ func TestZipDirectoryToTempFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_ErrorPaths(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "unzip_error_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a valid zip with one file
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
f, err := zw.Create("file.txt")
|
||||
require.NoError(t, err)
|
||||
_, err = f.Write([]byte("content"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
t.Run("mkdir failure", func(t *testing.T) {
|
||||
// Create a file where a directory should be
|
||||
blockedDir := filepath.Join(tempDir, "subdir")
|
||||
require.NoError(t, os.WriteFile(blockedDir, []byte("blocked"), 0o644))
|
||||
defer os.Remove(blockedDir)
|
||||
|
||||
// Create a zip that tries to create a file in that subdir
|
||||
var buf2 bytes.Buffer
|
||||
zw2 := zip.NewWriter(&buf2)
|
||||
_, err := zw2.Create("subdir/file.txt")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw2.Close())
|
||||
|
||||
err = UnzipFromMemory(buf2.Bytes(), tempDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("create file failure", func(t *testing.T) {
|
||||
// Create a directory where a file should be
|
||||
blockedFile := filepath.Join(tempDir, "blocked_file.txt")
|
||||
require.NoError(t, os.MkdirAll(blockedFile, 0o755))
|
||||
defer os.RemoveAll(blockedFile)
|
||||
|
||||
var buf2 bytes.Buffer
|
||||
zw2 := zip.NewWriter(&buf2)
|
||||
_, err := zw2.Create("blocked_file.txt")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw2.Close())
|
||||
|
||||
err = UnzipFromMemory(buf2.Bytes(), tempDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestZipDirectoryToMemory_ErrorPaths(t *testing.T) {
|
||||
t.Run("non-existent directory", func(t *testing.T) {
|
||||
_, err := ZipDirectoryToMemory("/non/existent/path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile_ErrorPaths(t *testing.T) {
|
||||
t.Run("non-existent directory", func(t *testing.T) {
|
||||
_, err := ZipDirectoryToTempFile("/non/existent/path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty source path", func(t *testing.T) {
|
||||
_, err := ZipDirectoryToTempFile("")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestZipDirectoryToMemory_NotADirectory(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "not_a_dir")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
_, err = ZipDirectoryToMemory(tempFile.Name())
|
||||
// filepath.Walk on a file succeeds and visits only that file
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile_NotADirectory(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "not_a_dir_tempfile")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
zf, err := ZipDirectoryToTempFile(tempFile.Name())
|
||||
assert.NoError(t, err)
|
||||
if err == nil {
|
||||
zf.Close()
|
||||
os.Remove(zf.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToMemory_OpenError(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "open_err_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
file := filepath.Join(tempDir, "unreadable.txt")
|
||||
require.NoError(t, os.WriteFile(file, []byte("test"), 0o000))
|
||||
|
||||
_, err = ZipDirectoryToMemory(tempDir)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -240,7 +350,6 @@ func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
|
||||
sourceDir: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ZipDirectoryToTempFile(tt.sourceDir)
|
||||
@@ -250,3 +359,83 @@ func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile_InternalErrorPaths(t *testing.T) {
|
||||
t.Run("unreadable file in directory", func(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "unreadable_zip_temp")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
file := filepath.Join(tempDir, "unreadable.txt")
|
||||
require.NoError(t, os.WriteFile(file, []byte("test"), 0o000))
|
||||
|
||||
_, err = ZipDirectoryToTempFile(tempDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_MoreEdgeCases(t *testing.T) {
|
||||
t.Run("empty zip data", func(t *testing.T) {
|
||||
err := UnzipFromMemory([]byte{}, t.TempDir())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("zip with absolute paths", func(t *testing.T) {
|
||||
// This tests if UnzipFromMemory handles files with names like "/tmp/hacker.txt"
|
||||
// zip.NewReader usually doesn't allow absolute paths easily, but let's see.
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
_, err := zw.Create("/tmp/test.txt")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
tempDir := t.TempDir()
|
||||
err = UnzipFromMemory(buf.Bytes(), tempDir)
|
||||
assert.NoError(t, err)
|
||||
// It should be joined with tempDir, not written to /tmp/test.txt
|
||||
expectedPath := filepath.Join(tempDir, "/tmp/test.txt")
|
||||
_, err = os.Stat(expectedPath)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_InvalidReader(t *testing.T) {
|
||||
err := UnzipFromMemory([]byte("invalid"), t.TempDir())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_FileCreateError(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a zip with one file
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
_, err := zw.Create("file.txt")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
// Create a directory where the file should be
|
||||
blockedFile := filepath.Join(tempDir, "file.txt")
|
||||
require.NoError(t, os.MkdirAll(blockedFile, 0o755))
|
||||
|
||||
err = UnzipFromMemory(buf.Bytes(), tempDir)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnzipFromMemory_DirCreateError(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a zip with one directory entry
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
_, err := zw.Create("subdir/")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
// Create a file where the directory should be
|
||||
blockedDir := filepath.Join(tempDir, "subdir")
|
||||
require.NoError(t, os.WriteFile(blockedDir, []byte("blocked"), 0o644))
|
||||
|
||||
err = UnzipFromMemory(buf.Bytes(), tempDir)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
+72
-44
@@ -31,22 +31,22 @@ type OCIIndex struct {
|
||||
} `json:"manifests"`
|
||||
}
|
||||
|
||||
// ExtractAlgorithm extracts the algorithm file from an OCI image directory.
|
||||
func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath string) (string, error) {
|
||||
// ExtractAlgorithm extracts the algorithm file and optionally requirements.txt from an OCI image directory.
|
||||
func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath, algoType string) (string, string, error) {
|
||||
// Read index.json to find manifest
|
||||
indexPath := filepath.Join(ociDir, "index.json")
|
||||
indexData, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read index.json: %w", err)
|
||||
return "", "", fmt.Errorf("failed to read index.json: %w", err)
|
||||
}
|
||||
|
||||
var index OCIIndex
|
||||
if err := json.Unmarshal(indexData, &index); err != nil {
|
||||
return "", fmt.Errorf("failed to parse index.json: %w", err)
|
||||
return "", "", fmt.Errorf("failed to parse index.json: %w", err)
|
||||
}
|
||||
|
||||
if len(index.Manifests) == 0 {
|
||||
return "", fmt.Errorf("no manifests found in index.json")
|
||||
return "", "", fmt.Errorf("no manifests found in index.json")
|
||||
}
|
||||
|
||||
// Get the first manifest digest
|
||||
@@ -56,7 +56,7 @@ func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath
|
||||
// Read manifest to find layers
|
||||
manifestData, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read manifest: %w", err)
|
||||
return "", "", fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest struct {
|
||||
@@ -65,50 +65,64 @@ func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath
|
||||
} `json:"layers"`
|
||||
}
|
||||
if err := json.Unmarshal(manifestData, &manifest); err != nil {
|
||||
return "", fmt.Errorf("failed to parse manifest: %w", err)
|
||||
return "", "", fmt.Errorf("failed to parse manifest: %w", err)
|
||||
}
|
||||
|
||||
// Extract layers to find algorithm files
|
||||
logger.Debug("found layers in manifest", "count", len(manifest.Layers))
|
||||
var algorithmPath string
|
||||
var requirementsPath string
|
||||
var allSeenFiles []string
|
||||
|
||||
// Iterate layers in reverse order to find user code first (usually in top layers)
|
||||
// Process layers in reverse order (top layers first)
|
||||
for i := len(manifest.Layers) - 1; i >= 0; i-- {
|
||||
layer := manifest.Layers[i]
|
||||
layerPath := filepath.Join(ociDir, "blobs", strings.Replace(layer.Digest, ":", "/", 1))
|
||||
|
||||
// Try to extract and find algorithm file
|
||||
algoPath, seenFiles, err := extractLayerAndFindAlgorithm(logger, layerPath, destPath)
|
||||
algoP, reqP, seenFiles, err := extractLayerAndFindAlgorithm(logger, layerPath, destPath, algoType)
|
||||
if len(seenFiles) > 0 {
|
||||
allSeenFiles = append(allSeenFiles, seenFiles...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("error extracting layer %s: %v", layer.Digest, err))
|
||||
logger.Warn("failed to extract layer", "digest", layer.Digest, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if algoPath != "" {
|
||||
return algoPath, nil
|
||||
if algoP != "" && algorithmPath == "" {
|
||||
algorithmPath = algoP
|
||||
}
|
||||
if reqP != "" && requirementsPath == "" {
|
||||
requirementsPath = reqP
|
||||
}
|
||||
|
||||
// If we found both, we can stop
|
||||
if algorithmPath != "" && (algoType != "python" || requirementsPath != "") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no algorithm file found in OCI image layers (seen: %v)", allSeenFiles)
|
||||
if algorithmPath == "" {
|
||||
return "", "", fmt.Errorf("no algorithm file found. Seen files: %v", allSeenFiles)
|
||||
}
|
||||
|
||||
return algorithmPath, requirementsPath, nil
|
||||
}
|
||||
|
||||
// extractLayerAndFindAlgorithm extracts a layer and searches for algorithm files.
|
||||
func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath string) (string, []string, error) {
|
||||
func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath, algoType string) (string, string, []string, error) {
|
||||
// Open layer file
|
||||
layerFile, err := os.Open(layerPath)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to open layer: %w", err)
|
||||
return "", "", nil, fmt.Errorf("failed to open layer: %w", err)
|
||||
}
|
||||
defer layerFile.Close()
|
||||
|
||||
// Decompress gzip
|
||||
gzReader, err := gzip.NewReader(layerFile)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
return "", "", nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
|
||||
@@ -116,7 +130,8 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
|
||||
tarReader := tar.NewReader(gzReader)
|
||||
|
||||
var algorithmPath string
|
||||
var seenFiles []string
|
||||
var requirementsPath string
|
||||
seenFiles := []string{}
|
||||
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
@@ -124,7 +139,7 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", seenFiles, fmt.Errorf("failed to read tar header: %w", err)
|
||||
return "", "", seenFiles, fmt.Errorf("failed to read tar header: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("inspecting file in layer", "name", header.Name, "type", header.Typeflag)
|
||||
@@ -137,7 +152,7 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
|
||||
seenFiles = append(seenFiles, header.Name)
|
||||
|
||||
// Check if this is an algorithm file or requirements.txt
|
||||
isAlgo := isAlgorithmFile(header.Name)
|
||||
isAlgo := isAlgorithmFile(header.Name, header.Mode, algoType)
|
||||
isReq := filepath.Base(header.Name) == "requirements.txt"
|
||||
|
||||
if isAlgo || isReq {
|
||||
@@ -151,56 +166,69 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
|
||||
targetPath := filepath.Join(destPath, cleanName)
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
return "", seenFiles, fmt.Errorf("failed to create dir: %w", err)
|
||||
return "", "", seenFiles, fmt.Errorf("failed to create dir: %w", err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return "", seenFiles, fmt.Errorf("failed to create file: %w", err)
|
||||
return "", "", seenFiles, fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(outFile, tarReader); err != nil {
|
||||
outFile.Close()
|
||||
return "", seenFiles, fmt.Errorf("failed to write file: %w", err)
|
||||
return "", "", seenFiles, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
outFile.Close()
|
||||
|
||||
if isAlgo {
|
||||
if isAlgo && algorithmPath == "" {
|
||||
algorithmPath = targetPath
|
||||
}
|
||||
// Continue scanning to extract other files (like requirements.txt)
|
||||
if isReq && requirementsPath == "" {
|
||||
requirementsPath = targetPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return algorithmPath, seenFiles, nil
|
||||
return algorithmPath, requirementsPath, seenFiles, nil
|
||||
}
|
||||
|
||||
// isAlgorithmFile checks if a file is likely an algorithm file based on its
// name, mode and expected algorithm type.
//
// Only the base name of filename is considered, compared case-insensitively.
// mode carries the permission bits of the entry (the caller passes the tar
// header mode). The decision depends on algoType:
//
//   - "python": the name ends in ".py".
//   - "wasm":   the name ends in ".wasm" or ".wat".
//   - "bin":    the name has no known non-binary extension and either
//     contains a common algorithm name or has at least one execute bit set.
//   - "docker": always false; docker algorithms are the whole image.
//   - anything else (including ""): always false, so callers must pass an
//     explicit type.
func isAlgorithmFile(filename string, mode int64, algoType string) bool {
	baseLower := strings.ToLower(filepath.Base(filename))

	switch algoType {
	case "python":
		return strings.HasSuffix(baseLower, ".py")
	case "wasm":
		return strings.HasSuffix(baseLower, ".wasm") || strings.HasSuffix(baseLower, ".wat")
	case "bin":
		// Ensure it doesn't have a known non-binary extension
		nonBinExts := []string{".py", ".wasm", ".wat", ".js", ".sh", ".csv", ".json", ".txt", ".md"}
		for _, ext := range nonBinExts {
			if strings.HasSuffix(baseLower, ext) {
				return false
			}
		}

		// Check for common names; a substring match is enough.
		algorithmNames := []string{"algorithm", "main", "run", "execute"}
		for _, name := range algorithmNames {
			if strings.Contains(baseLower, name) {
				return true
			}
		}

		// Check if it's executable (at least one 'x' bit set)
		return mode&0o111 != 0
	case "docker":
		// Docker algorithms are the whole image, this function shouldn't be used for them
		return false
	default:
		// Unknown or empty algoType - no generic fallback to ensure explicit type usage
		return false
	}
}
|
||||
|
||||
// ExtractDataset extracts dataset files from an OCI image directory.
|
||||
@@ -328,7 +356,7 @@ func extractLayerDataFiles(layerPath, destPath string) ([]string, error) {
|
||||
|
||||
// isDataFile checks if a file is likely a dataset file.
|
||||
func isDataFile(filename string) bool {
|
||||
dataExts := []string{".csv", ".json", ".txt", ".parquet", ".arrow", ".dat"}
|
||||
dataExts := []string{".csv", ".json", ".txt", ".parquet", ".arrow", ".dat", ".zip", ".tar", ".gz", ".tgz", ".tar.gz"}
|
||||
|
||||
baseLower := strings.ToLower(filepath.Base(filename))
|
||||
|
||||
|
||||
+246
-29
@@ -24,28 +24,28 @@ func TestIsAlgorithmFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
mode int64
|
||||
algoType string
|
||||
want bool
|
||||
}{
|
||||
{"Python file", "algorithm.py", true},
|
||||
{"WASM file", "module.wasm", true},
|
||||
{"WAT file", "module.wat", true},
|
||||
{"JavaScript file", "script.js", true},
|
||||
{"Shell script", "run.sh", true},
|
||||
{"Main python file", "main.py", true},
|
||||
{"Execute file", "execute.py", true},
|
||||
{"Algorithm name in path", "src/algorithm_v2.py", true},
|
||||
{"Random python file", "helper.py", true},
|
||||
{"CSV data file", "data.csv", false},
|
||||
{"JSON config file", "config.json", false},
|
||||
{"Text file", "readme.txt", false},
|
||||
{"Binary file", "data.bin", false},
|
||||
{"Uppercase extension", "MAIN.PY", true},
|
||||
{"Mixed case", "Algorithm.Py", true},
|
||||
{"Python file", "algorithm.py", 0o644, "python", true},
|
||||
{"WASM file", "module.wasm", 0o644, "wasm", true},
|
||||
{"WAT file", "module.wat", 0o644, "wasm", true},
|
||||
{"Python file as bin", "algorithm.py", 0o755, "bin", false},
|
||||
{"Main python file", "main.py", 0o644, "python", true},
|
||||
{"Binary file with common name", "algorithm", 0o644, "bin", true},
|
||||
{"Binary file with common name run", "run", 0o644, "bin", true},
|
||||
{"Executable binary", "my-app", 0o755, "bin", true},
|
||||
{"CSV data file", "data.csv", 0o755, "python", false},
|
||||
{"JSON config file", "config.json", 0o755, "wasm", false},
|
||||
{"Text file", "readme.txt", 0o755, "bin", false},
|
||||
{"Uppercase extension", "MAIN.PY", 0o644, "python", true},
|
||||
{"Mixed case", "Algorithm.Py", 0o644, "python", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isAlgorithmFile(tt.filename)
|
||||
got := isAlgorithmFile(tt.filename, tt.mode, tt.algoType)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func TestExtractAlgorithm(t *testing.T) {
|
||||
|
||||
t.Run("missing index.json", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
_, err := ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read index.json")
|
||||
})
|
||||
@@ -93,7 +93,7 @@ func TestExtractAlgorithm(t *testing.T) {
|
||||
err := os.WriteFile(filepath.Join(tempDir, "index.json"), []byte("not json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
|
||||
_, _, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse index.json")
|
||||
})
|
||||
@@ -105,14 +105,14 @@ func TestExtractAlgorithm(t *testing.T) {
|
||||
err := os.WriteFile(filepath.Join(tempDir, "index.json"), data, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
|
||||
_, _, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no manifests found")
|
||||
})
|
||||
|
||||
t.Run("successful extraction", func(t *testing.T) {
|
||||
ociDir, destDir := setupTestOCIImage(t, "algorithm.py", testPythonScript)
|
||||
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, algoPath)
|
||||
assert.Contains(t, algoPath, "algorithm.py")
|
||||
@@ -471,7 +471,7 @@ func TestExtractAlgorithmWithRequirements(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, algoPath, "main.py")
|
||||
|
||||
@@ -537,7 +537,7 @@ func TestExtractAlgorithmNoAlgoFile(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no algorithm file found")
|
||||
})
|
||||
@@ -600,6 +600,41 @@ func TestExtractDatasetNoDataFiles(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no dataset files found")
|
||||
})
|
||||
|
||||
t.Run("corrupt layer file", func(t *testing.T) {
|
||||
ociDir := t.TempDir()
|
||||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||||
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "layer123"), []byte("not a gzip"), 0o644))
|
||||
|
||||
manifest := struct {
|
||||
Layers []struct {
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}{
|
||||
Layers: []struct {
|
||||
Digest string `json:"digest"`
|
||||
}{{Digest: "sha256:layer123"}},
|
||||
}
|
||||
manifestData, _ := json.Marshal(manifest)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
|
||||
|
||||
index := OCIIndex{
|
||||
SchemaVersion: 2,
|
||||
Manifests: []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int `json:"size"`
|
||||
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
|
||||
}
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
// ExtractDataset logs a warning and continues when a single layer fails, but returns an error if every layer fails
|
||||
_, err := ExtractDataset(ociDir, t.TempDir())
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractAlgorithmInvalidManifest(t *testing.T) {
|
||||
@@ -626,7 +661,7 @@ func TestExtractAlgorithmInvalidManifest(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse manifest")
|
||||
})
|
||||
@@ -654,7 +689,7 @@ func TestExtractAlgorithmMissingManifest(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read manifest")
|
||||
})
|
||||
@@ -723,7 +758,7 @@ func TestExtractAlgorithmWithDirectory(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, algoPath, "main.py")
|
||||
})
|
||||
@@ -795,7 +830,7 @@ func TestExtractAlgorithmPathTraversal(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, algoPath, "algorithm.py")
|
||||
|
||||
@@ -815,7 +850,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
|
||||
err := os.WriteFile(layerPath, []byte("not gzip"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no algorithm file found")
|
||||
})
|
||||
@@ -833,7 +868,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
|
||||
err = os.WriteFile(layerPath, buf.Bytes(), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no algorithm file found")
|
||||
})
|
||||
@@ -867,7 +902,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no algorithm file found")
|
||||
})
|
||||
@@ -918,3 +953,185 @@ func TestExtractDatasetErrorPathsAdditional(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "no dataset files found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractAlgorithmAdditionalTypes(t *testing.T) {
|
||||
t.Run("isAlgorithmFile additional types", func(t *testing.T) {
|
||||
assert.False(t, isAlgorithmFile("any", 0o644, "docker"))
|
||||
assert.False(t, isAlgorithmFile("any", 0o644, "unknown"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractAlgorithmErrorPathsInternal(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
|
||||
t.Run("failed to create directory", func(t *testing.T) {
|
||||
ociDir, destDir := setupTestOCIImage(t, "algorithm.py", "print('hello')")
|
||||
|
||||
// Create a file where a directory should be
|
||||
blockedDir := filepath.Join(destDir, "blocked")
|
||||
require.NoError(t, os.WriteFile(blockedDir, []byte("data"), 0o644))
|
||||
|
||||
// Try to extract an algorithm that would need to create a directory where a file exists
|
||||
layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
hdr := &tar.Header{
|
||||
Name: "blocked/main.py",
|
||||
Mode: 0o644,
|
||||
Size: int64(len("print(1)")),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr))
|
||||
_, _ = tw.Write([]byte("print(1)"))
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
require.NoError(t, os.WriteFile(layerPath, buf.Bytes(), 0o644))
|
||||
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to create file", func(t *testing.T) {
|
||||
ociDir, destDir := setupTestOCIImage(t, "algorithm.py", "print('hello')")
|
||||
|
||||
// Create a directory where a file should be
|
||||
blockedFile := filepath.Join(destDir, "algorithm.py")
|
||||
require.NoError(t, os.MkdirAll(blockedFile, 0o755))
|
||||
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractDatasetErrorPathsInternal(t *testing.T) {
|
||||
t.Run("failed to create directory for dataset", func(t *testing.T) {
|
||||
ociDir, destDir := setupTestOCIImage(t, "data.csv", "a,b,c")
|
||||
|
||||
blockedDir := filepath.Join(destDir, "blocked")
|
||||
require.NoError(t, os.WriteFile(blockedDir, []byte("data"), 0o644))
|
||||
|
||||
layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
hdr := &tar.Header{
|
||||
Name: "blocked/data.csv",
|
||||
Mode: 0o644,
|
||||
Size: int64(len("a,b")),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr))
|
||||
_, _ = tw.Write([]byte("a,b"))
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
require.NoError(t, os.WriteFile(layerPath, buf.Bytes(), 0o644))
|
||||
|
||||
_, err := ExtractDataset(ociDir, destDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractAlgorithm_PythonNoRequirements(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
ociDir, destDir := setupTestOCIImage(t, "main.py", testPythonScript)
|
||||
algoPath, reqPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, algoPath)
|
||||
assert.Empty(t, reqPath)
|
||||
}
|
||||
|
||||
func TestExtractDataset_MultipleLayers(t *testing.T) {
|
||||
ociDir := t.TempDir()
|
||||
destDir := t.TempDir()
|
||||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||||
|
||||
createLayer := func(name, filename, content string) string {
|
||||
path := filepath.Join(blobsDir, name)
|
||||
f, err := os.Create(path)
|
||||
require.NoError(t, err)
|
||||
gw := gzip.NewWriter(f)
|
||||
tw := tar.NewWriter(gw)
|
||||
hdr := &tar.Header{Name: filename, Mode: 0o644, Size: int64(len(content))}
|
||||
err = tw.WriteHeader(hdr)
|
||||
require.NoError(t, err)
|
||||
_, err = tw.Write([]byte(content))
|
||||
require.NoError(t, err)
|
||||
err = tw.Close()
|
||||
require.NoError(t, err)
|
||||
err = gw.Close()
|
||||
require.NoError(t, err)
|
||||
err = f.Close()
|
||||
require.NoError(t, err)
|
||||
return "sha256:" + name
|
||||
}
|
||||
|
||||
layer1 := createLayer("l1", "data1.csv", "1,2")
|
||||
layer2 := createLayer("l2", "data2.csv", "3,4")
|
||||
|
||||
manifest := struct {
|
||||
Layers []struct {
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}{
|
||||
Layers: []struct {
|
||||
Digest string `json:"digest"`
|
||||
}{{Digest: layer1}, {Digest: layer2}},
|
||||
}
|
||||
manifestData, err := json.Marshal(manifest)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "m1"), manifestData, 0o644))
|
||||
|
||||
index := OCIIndex{
|
||||
SchemaVersion: 2,
|
||||
Manifests: []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int `json:"size"`
|
||||
}{{Digest: "sha256:m1", Size: len(manifestData)}},
|
||||
}
|
||||
indexData, err := json.Marshal(index)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
files, err := ExtractDataset(ociDir, destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, files, 2)
|
||||
}
|
||||
|
||||
func TestExtractAlgorithm_ErrorPaths(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
|
||||
t.Run("invalid layer gzip", func(t *testing.T) {
|
||||
ociDir := t.TempDir()
|
||||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "l1"), []byte("not gzip"), 0o644))
|
||||
|
||||
manifest := struct {
|
||||
Layers []struct {
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}{
|
||||
Layers: []struct {
|
||||
Digest string `json:"digest"`
|
||||
}{{Digest: "sha256:l1"}},
|
||||
}
|
||||
manifestData, _ := json.Marshal(manifest)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "m1"), manifestData, 0o644))
|
||||
|
||||
index := OCIIndex{
|
||||
SchemaVersion: 2,
|
||||
Manifests: []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int `json:"size"`
|
||||
}{{Digest: "sha256:m1", Size: len(manifestData)}},
|
||||
}
|
||||
indexData, _ := json.Marshal(index)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||||
|
||||
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, t.TempDir(), "bin")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no algorithm file found")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -109,6 +109,23 @@ func (s *SkopeoClient) Inspect(ctx context.Context, imageRef string) (*ImageMani
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToDockerArchive converts an OCI directory to a Docker archive tarball.
|
||||
func (s *SkopeoClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
|
||||
args := []string{"copy", "--insecure-policy", "--src-tls-verify=false", "--dest-tls-verify=false", "oci:" + ociDir, "docker-archive:" + destFile}
|
||||
|
||||
cmd := exec.CommandContext(ctx, s.skopeoPath, args...)
|
||||
cmd.Env = append(os.Environ(),
|
||||
OCICryptKeyproviderConfig+"="+DefaultOCICryptConfig)
|
||||
cmd.Dir = s.workDir
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("skopeo copy to docker-archive failed: %w\nOutput: %s", err, string(output))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLocalImagePath returns the path to a local OCI image directory.
|
||||
func (s *SkopeoClient) GetLocalImagePath(name string) string {
|
||||
return filepath.Join(s.workDir, name)
|
||||
|
||||
@@ -183,3 +183,19 @@ func TestSkopeoClientPullAndDecryptEncrypted(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "skopeo copy failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSkopeoClient_ToDockerArchive(t *testing.T) {
|
||||
workDir := t.TempDir()
|
||||
client, err := NewSkopeoClient(workDir)
|
||||
if err != nil {
|
||||
t.Skip("skopeo not installed, skipping test")
|
||||
}
|
||||
|
||||
t.Run("invalid oci directory", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
destFile := filepath.Join(t.TempDir(), "archive.tar")
|
||||
err := client.ToDockerArchive(ctx, "/non/existent/oci/dir", destFile)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "skopeo copy to docker-archive failed")
|
||||
})
|
||||
}
|
||||
|
||||
+11
-5
@@ -40,6 +40,11 @@ The service is configured using environment variables from the following table.
|
||||
| `-algo-kbs-path` | Algorithm KBS resource path (e.g., 'default/key/algo-key') |
|
||||
| `-dataset-source-urls` | Comma-separated dataset source URLs |
|
||||
| `-dataset-kbs-paths` | Comma-separated dataset KBS resource paths |
|
||||
| `-algo-type` | Algorithm execution type (binary, python, docker, etc.) |
|
||||
| `-algo-args` | Comma-separated algorithm arguments |
|
||||
| `-algo-hash` | Expected SHA3-256 hash of decrypted algorithm (hex) |
|
||||
| `-dataset-hash` | Expected SHA3-256 hash of decrypted dataset (hex) |
|
||||
| `-dataset-decompress` | Whether to decompress datasets (`true` or `false`) |
|
||||
|
||||
### Optional Flags
|
||||
|
||||
@@ -114,11 +119,12 @@ go run ./test/cvms/main.go \
|
||||
|
||||
## Notes
|
||||
|
||||
- **Either** `-algo-path` **OR** (`-algo-source-url` AND `-algo-kbs-path`) must be provided
|
||||
- When using remote datasets, `-dataset-source-urls` and `-dataset-kbs-paths` must have the same number of comma-separated values
|
||||
- The `-kbs-url` flag should be provided when using any remote resources
|
||||
- For remote resources, the hash values in the manifest are currently placeholders (all zeros). In production, these should be the actual hashes of the **decrypted** data
|
||||
- See [TESTING_REMOTE_RESOURCES.md](../TESTING_REMOTE_RESOURCES.md) for a complete guide on testing remote resource downloads with KBS attestation
|
||||
- **Either** `-algo-path` **OR** (`-algo-source-url` AND `-algo-kbs-path`) must be provided.
|
||||
- When using remote datasets, `-dataset-source-urls` and `-dataset-kbs-paths` must have the same number of comma-separated values.
|
||||
- The `-kbs-url` flag should be provided when using any remote resources.
|
||||
- **Checksum Verification**: For remote resources, you must provide the actual SHA3-256 hash of the **decrypted plaintext** content via `-algo-hash` and `-dataset-hash`. The Agent will verify this hash after downloading and decrypting the resource.
|
||||
- **Calculating Hashes**: Use `cocos-cli checksum <path>` on your local source files (or directories) to generate the correct hash for the manifest.
|
||||
- See [TESTING_REMOTE_RESOURCES.md](../../agent/TESTING_REMOTE_RESOURCES.md) for a complete guide on testing remote resource downloads with KBS attestation.
|
||||
|
||||
## Architecture
|
||||
|
||||
|
||||
+25
-9
@@ -140,13 +140,15 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
|
||||
return
|
||||
}
|
||||
dataHash, err := internal.Checksum(dataPath)
|
||||
dataHash, err := internal.ChecksumHex(dataPath)
|
||||
if err != nil {
|
||||
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
|
||||
return
|
||||
}
|
||||
s.logger.Info("local dataset checksum", "path", dataPath, "hash", dataHash)
|
||||
|
||||
datasets = append(datasets, &cvms.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
|
||||
hashBytes, _ := hex.DecodeString(dataHash)
|
||||
datasets = append(datasets, &cvms.Dataset{Hash: hashBytes, UserKey: pubPem.Bytes})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,11 +172,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
algoHashBytes = make([]byte, 32)
|
||||
}
|
||||
|
||||
var algoArgs []string
|
||||
if algoArgsString != "" {
|
||||
algoArgs = strings.Split(algoArgsString, ",")
|
||||
}
|
||||
|
||||
algorithm = &cvms.Algorithm{
|
||||
Hash: algoHashBytes,
|
||||
UserKey: pubPem.Bytes,
|
||||
AlgoType: algoType,
|
||||
AlgoArgs: strings.Split(algoArgsString, ","),
|
||||
AlgoArgs: algoArgs,
|
||||
Source: &cvms.Source{
|
||||
Type: "oci-image",
|
||||
Url: algoSourceURL,
|
||||
@@ -184,16 +191,25 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
}
|
||||
} else {
|
||||
// Direct upload mode - use local file
|
||||
if algoPath == "" {
|
||||
s.logger.Error("algorithm path is required when not using remote source")
|
||||
return
|
||||
}
|
||||
algoHash, err := internal.Checksum(algoPath)
|
||||
fileHash, err := internal.ChecksumHex(algoPath)
|
||||
if err != nil {
|
||||
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
|
||||
return
|
||||
}
|
||||
algorithm = &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes}
|
||||
s.logger.Info("local algorithm checksum", "path", algoPath, "hash", fileHash)
|
||||
|
||||
var algoArgs []string
|
||||
if algoArgsString != "" {
|
||||
algoArgs = strings.Split(algoArgsString, ",")
|
||||
}
|
||||
|
||||
hashBytes, _ := hex.DecodeString(fileHash)
|
||||
algorithm = &cvms.Algorithm{
|
||||
Hash: hashBytes,
|
||||
UserKey: pubPem.Bytes,
|
||||
AlgoType: algoType,
|
||||
AlgoArgs: algoArgs,
|
||||
}
|
||||
}
|
||||
|
||||
// Build KBS config
|
||||
|
||||
Reference in New Issue
Block a user