NOISSUE - Enhance OCI image extraction to return algorithm and requirements paths, and add deferred cleanup for temporary files (#586)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* feat: Enhance OCI image extraction to return algorithm and requirements paths, and add deferred cleanup for temporary files.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: implement deterministic zipping and enhance checksum verification for resources

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update component build sources, add gRPC health checks to the CVM server, and refine algorithm argument handling and documentation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* docs: Update remote resources testing guide with `sudo` for KBS, algorithm result saving, `requirements.txt`, and `algo-args` for RVPS.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Explicitly ignore `stderr.Write` return values and add minor whitespace in tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add comprehensive error path and edge case tests for file, zip, OCI, and agent components.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add mutexes for thread-safe algorithm execution and expand recognized data file extensions to include common archive formats.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add OCI extraction tests for Python algorithms and multi-layer datasets, refactor algorithm execution for testability, and enhance algorithm stop and error handling tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Add error assertions to OCI extraction test helpers and remove an unused mock exec command.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Improve error handling test coverage for algorithm execution and OCI resource extraction.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Improve algorithm process termination, enhance computation error handling, and add concurrency safety to agent service.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-03-27 16:23:52 +03:00
committed by GitHub
parent 80bf813c48
commit b44780df95
24 changed files with 1799 additions and 229 deletions
+25 -10
View File
@@ -92,7 +92,7 @@ EOF
mkdir -p kbs-data/as kbs-data/rvps kbs-data/repository
# Start KBS
../target/release/kbs --config-file kbs-config.toml
sudo ../target/release/kbs --config-file kbs-config.toml
```
KBS will listen on `http://localhost:8080`
@@ -115,6 +115,7 @@ cat > lin_reg.py << 'EOF'
import pandas as pd
from sklearn.linear_model import LinearRegression
import sys
import os
# Load dataset
data = pd.read_csv(sys.argv[1])
@@ -126,34 +127,46 @@ model = LinearRegression()
model.fit(X, y)
# Save results
os.makedirs("results", exist_ok=True)
with open("results/output.txt", "w") as f:
f.write(f"Coefficients: {model.coef_}\n")
f.write(f"Intercept: {model.intercept_}\n")
print(f"Coefficients: {model.coef_}")
print(f"Intercept: {model.intercept_}")
EOF
# 2. Create a Dockerfile
# 2. Create requirements.txt
cat > requirements.txt << 'EOF'
pandas
scikit-learn
EOF
# 3. Create a Dockerfile
cat > Dockerfile << 'EOF'
FROM python:3.9-slim
RUN pip install pandas scikit-learn
COPY lin_reg.py /app/algorithm.py
COPY requirements.txt /app/requirements.txt
WORKDIR /app
ENTRYPOINT ["python", "algorithm.py"]
EOF
# 3. Build the image
# 4. Build the image
docker build -t localhost:5000/lin-reg-algo:v1.0 .
docker push localhost:5000/lin-reg-algo:v1.0
# 4. Generate and store key
# 5. Generate and store key
openssl rand -out algo.key 32
# 5. Store key in KBS using kbs-client
# 6. Store key in KBS using kbs-client
../target/release/kbs-client --url http://localhost:8080 config \
--auth-private-key kbs-admin.key \
set-resource \
--path default/key/algo-key \
--resource-file algo.key
# 6. Encrypt the image using Host Skopeo + Docker Keyprovider
# 7. Encrypt the image using Host Skopeo + Docker Keyprovider
# Start Keyprovider in background
docker run -d --rm --name keyprovider --network host \
-v "$PWD:/work" -w /work \
@@ -255,10 +268,12 @@ HOST_IP=$(ip -4 addr show | grep -oP '(?<=inet\s)\d+(\.\d+){3}' | grep -v 127.0.
Start CVMS server:
```bash
# Calculate SHA3-256 of decrypted files using cocos-cli
# Calculate SHA3-256 of decrypted files using cocos-cli or cvms-test
# NOTE: We use the hash of the original plaintext files, as the Agent validates the decrypted content.
# Redirect stderr to stdout (2>&1) because cocos-cli prints to stderr
# For single files, use the file hash. For directories, use the hash of the directory (which the tools zip deterministically).
ALGO_HASH=$(./build/cocos-cli checksum lin_reg.py 2>&1 | awk '{print $NF}')
DATASET_HASH=$(./build/cocos-cli checksum iris.csv 2>&1 | awk '{print $NF}')
go build -o build/cvms-test ./test/cvms/main.go
@@ -266,11 +281,11 @@ HOST=$HOST_IP PORT=7001 ./build/cvms-test \
-public-key-path ./public.pem \
-attested-tls-bool false \
-kbs-url http://$HOST_IP:8080 \
-algo-type oci-image \
-algo-type python \
-algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \
-algo-kbs-path default/key/algo-key \
-algo-hash $ALGO_HASH \
-dataset-type oci-image \
-algo-args datasets/dataset_0.csv \
-dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \
-dataset-kbs-paths default/key/dataset-key \
-dataset-hash $DATASET_HASH
+14 -6
View File
@@ -3,16 +3,21 @@
package binary
import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"sync"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
)
var execCommand = exec.Command
var _ algorithm.Algorithm = (*binary)(nil)
type binary struct {
@@ -21,6 +26,7 @@ type binary struct {
stdout io.Writer
args []string
cmd *exec.Cmd
mu sync.Mutex
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string, cmpID string) algorithm.Algorithm {
@@ -33,13 +39,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
}
func (b *binary) Run() error {
b.cmd = exec.Command(b.algoFile, b.args...)
b.mu.Lock()
b.cmd = execCommand(b.algoFile, b.args...)
b.cmd.Stderr = b.stderr
b.cmd.Stdout = b.stdout
if err := b.cmd.Start(); err != nil {
b.mu.Unlock()
return fmt.Errorf("error starting algorithm: %v", err)
}
b.mu.Unlock()
if err := b.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
@@ -49,11 +58,10 @@ func (b *binary) Run() error {
}
func (b *binary) Stop() error {
if b.cmd == nil {
return nil
}
b.mu.Lock()
defer b.mu.Unlock()
if b.cmd.ProcessState != nil && b.cmd.ProcessState.Exited() {
if b.cmd == nil {
return nil
}
@@ -61,7 +69,7 @@ func (b *binary) Stop() error {
return nil
}
if err := b.cmd.Process.Kill(); err != nil {
if err := b.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
return fmt.Errorf("error stopping algorithm: %v", err)
}
+70
View File
@@ -4,10 +4,14 @@ package binary
import (
"bytes"
"io"
"log/slog"
"os"
"os/exec"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
@@ -73,6 +77,7 @@ func TestBinaryRun(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args, "").(*binary)
@@ -98,3 +103,68 @@ func TestBinaryRun(t *testing.T) {
})
}
}
// TestStop exercises binary.Stop for three situations: no command was ever
// created, the command is still running, and the command has already exited.
func TestStop(t *testing.T) {
	t.Run("stop nil cmd", func(t *testing.T) {
		b := &binary{}
		err := b.Stop()
		assert.NoError(t, err)
	})

	t.Run("stop with running process", func(t *testing.T) {
		// Start the process directly rather than through Run: Run waits for
		// the command to finish, so Stop would only be reached after the full
		// 10 second sleep and would never see a live process.
		b := &binary{}
		b.cmd = exec.Command("sleep", "10")
		if err := b.cmd.Start(); err != nil {
			t.Fatalf("Failed to start command: %v", err)
		}

		err := b.Stop()
		assert.NoError(t, err)

		// Reap the child. Wait returns an error for a process that was
		// killed, which shows Stop interrupted it before the sleep ended.
		assert.Error(t, b.cmd.Wait())
	})

	t.Run("stop already exited", func(t *testing.T) {
		b := &binary{
			algoFile: "echo",
			args:     []string{"test"},
			stdout:   io.Discard,
			stderr:   io.Discard,
		}
		// Run returns only after the command has finished, so Stop is called
		// on a process that is already gone.
		if err := b.Run(); err != nil {
			t.Fatal(err)
		}

		err := b.Stop()
		assert.NoError(t, err)
	})
}
// TestRunError checks that Run returns an error when the command produced by
// execCommand cannot be started.
func TestRunError(t *testing.T) {
	// Swap in a factory whose command fails to start; put the real one back
	// once the test is over.
	original := execCommand
	t.Cleanup(func() { execCommand = original })
	execCommand = mockExecCommandError

	svc := new(mocks.Service)
	svc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
	lg := slog.New(slog.NewTextHandler(os.Stdout, nil))

	algo := NewAlgorithm(lg, svc, "test", nil, "").(*binary)

	assert.Error(t, algo.Run())
}
// mockExecCommandError is a test stand-in for exec.Command. It ignores both
// arguments and builds a command for a binary name that is not expected to
// exist, so starting the returned command fails.
func mockExecCommandError(command string, args ...string) *exec.Cmd {
	const missingBinary = "non_existent_binary_for_sure_12345"
	return exec.Command(missingBinary)
}
// TestHelperProcess is not a real test. When GO_WANT_HELPER_PROCESS is "1"
// it exits the process with status 0; otherwise it returns without doing
// anything.
func TestHelperProcess(t *testing.T) {
	if want := os.Getenv("GO_WANT_HELPER_PROCESS"); want == "1" {
		os.Exit(0)
	}
}
+16 -9
View File
@@ -4,12 +4,14 @@ package python
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"sync"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
@@ -40,6 +42,7 @@ type python struct {
requirementsFile string
args []string
cmd *exec.Cmd
mu sync.Mutex
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string, cmpID string) algorithm.Algorithm {
@@ -60,6 +63,12 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requir
func (p *python) Run() error {
venvPath := "venv"
defer func() {
if err := os.RemoveAll(venvPath); err != nil {
_, _ = p.stderr.Write([]byte(fmt.Sprintf("error removing virtual environment: %v\n", err)))
}
}()
createVenvCmd := exec.Command(p.runtime, "-m", "venv", venvPath)
createVenvCmd.Stderr = p.stderr
createVenvCmd.Stdout = p.stdout
@@ -86,31 +95,29 @@ func (p *python) Run() error {
}
args := append([]string{p.algoFile}, p.args...)
p.mu.Lock()
p.cmd = exec.Command(pythonPath, args...)
p.cmd.Stderr = p.stderr
p.cmd.Stdout = p.stdout
if err := p.cmd.Start(); err != nil {
p.mu.Unlock()
return fmt.Errorf("error starting algorithm: %v", err)
}
p.mu.Unlock()
if err := p.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
}
if err := os.RemoveAll(venvPath); err != nil {
return fmt.Errorf("error removing virtual environment: %v", err)
}
return nil
}
func (p *python) Stop() error {
if p.cmd == nil {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.cmd.ProcessState != nil && p.cmd.ProcessState.Exited() {
if p.cmd == nil {
return nil
}
@@ -118,7 +125,7 @@ func (p *python) Stop() error {
return nil
}
if err := p.cmd.Process.Kill(); err != nil {
if err := p.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
return fmt.Errorf("error stopping algorithm: %v", err)
}
+91
View File
@@ -8,10 +8,13 @@ import (
"io"
"log/slog"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
"google.golang.org/grpc/metadata"
@@ -146,3 +149,91 @@ func TestRunWithRequirements(t *testing.T) {
t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String())
}
}
// TestStop covers python.Stop when no command exists, when the command is
// still running, and when the command has already finished.
func TestStop(t *testing.T) {
	t.Run("stop nil cmd", func(t *testing.T) {
		p := &python{}
		if err := p.Stop(); err != nil {
			t.Errorf("Expected nil error, got %v", err)
		}
	})

	t.Run("stop with running process", func(t *testing.T) {
		sleeper := exec.Command("python3", "-c", "import time; time.sleep(10)")
		p := &python{
			stderr: io.Discard,
			stdout: io.Discard,
			cmd:    sleeper,
		}
		if err := sleeper.Start(); err != nil {
			t.Fatalf("Failed to start command: %v", err)
		}

		if err := p.Stop(); err != nil {
			t.Errorf("Expected nil error, got %v", err)
		}

		// Reap the child; the result of Wait is intentionally ignored.
		_ = sleeper.Wait()
	})

	t.Run("stop already exited", func(t *testing.T) {
		finished := exec.Command("python3", "-c", "print(1)")
		p := &python{cmd: finished}
		// Run blocks until the interpreter has exited.
		if err := finished.Run(); err != nil {
			t.Fatal(err)
		}

		if err := p.Stop(); err != nil {
			t.Errorf("Expected nil error, got %v", err)
		}
	})
}
// TestRun_Errors checks the error messages python.Run returns when the
// runtime cannot be executed and when installing requirements fails.
func TestRun_Errors(t *testing.T) {
	t.Run("invalid runtime error", func(t *testing.T) {
		algo := &python{
			algoFile: "algo.py",
			runtime:  "non-existent-python",
			stderr:   io.Discard,
			stdout:   io.Discard,
		}

		err := algo.Run()
		// require stops the subtest on a nil error; with assert the following
		// err.Error() call would panic instead of reporting a failure.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "error creating virtual environment")
	})

	t.Run("pip install failure", func(t *testing.T) {
		tmpDir, err := os.MkdirTemp("", "python-err-test")
		require.NoError(t, err)
		defer os.RemoveAll(tmpDir)

		scriptPath := filepath.Join(tmpDir, "test.py")
		require.NoError(t, os.WriteFile(scriptPath, []byte("print(1)"), 0o644))

		// A pinned version of a package name that is not expected to exist,
		// so the requirements install step fails.
		reqPath := filepath.Join(tmpDir, "requirements.txt")
		require.NoError(t, os.WriteFile(reqPath, []byte("non-existent-package==9.9.9"), 0o644))

		algo := &python{
			algoFile:         scriptPath,
			requirementsFile: reqPath,
			runtime:          "python3",
			stderr:           io.Discard,
			stdout:           io.Discard,
		}

		err = algo.Run()
		// Same reasoning as above: never dereference a nil error.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "error installing requirements")
	})
}
func TestNewAlgorithmEmptyRuntime(t *testing.T) {
eventsSvc := new(mocks.Service)
algo := NewAlgorithm(slog.Default(), eventsSvc, "", "req.txt", "algo.py", nil, "")
p := algo.(*python)
if p.runtime != PyRuntime {
t.Errorf("Expected default runtime %s, got %s", PyRuntime, p.runtime)
}
}
+14 -6
View File
@@ -3,16 +3,21 @@
package wasm
import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"sync"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
)
var execCommand = exec.Command
const wasmRuntime = "wasmedge"
var mapDirOption = []string{"--dir", ".:" + algorithm.ResultsDir}
@@ -25,6 +30,7 @@ type wasm struct {
stdout io.Writer
args []string
cmd *exec.Cmd
mu sync.Mutex
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string, algoFile, cmpID string) algorithm.Algorithm {
@@ -39,13 +45,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string,
func (w *wasm) Run() error {
args := append(mapDirOption, w.algoFile)
args = append(args, w.args...)
w.cmd = exec.Command(wasmRuntime, args...)
w.mu.Lock()
w.cmd = execCommand(wasmRuntime, args...)
w.cmd.Stderr = w.stderr
w.cmd.Stdout = w.stdout
if err := w.cmd.Start(); err != nil {
w.mu.Unlock()
return fmt.Errorf("error starting algorithm: %v", err)
}
w.mu.Unlock()
if err := w.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
@@ -55,11 +64,10 @@ func (w *wasm) Run() error {
}
func (w *wasm) Stop() error {
if w.cmd == nil {
return nil
}
w.mu.Lock()
defer w.mu.Unlock()
if w.cmd.ProcessState != nil && w.cmd.ProcessState.Exited() {
if w.cmd == nil {
return nil
}
@@ -67,7 +75,7 @@ func (w *wasm) Stop() error {
return nil
}
if err := w.cmd.Process.Kill(); err != nil {
if err := w.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
return fmt.Errorf("error stopping algorithm: %v", err)
}
+97 -7
View File
@@ -7,15 +7,18 @@ import (
"os"
"os/exec"
"testing"
"time"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
const testWasm = "test.wasm"
func TestNewAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "test.wasm"
algoFile := testWasm
args := []string{"arg1", "arg2"}
algo := NewAlgorithm(logger, eventsSvc, args, algoFile, "")
@@ -49,14 +52,18 @@ func TestRunError(t *testing.T) {
execCommand = mockExecCommandError
defer func() { execCommand = exec.Command }()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "test.wasm"
algoFile := testWasm
args := []string{"arg1", "arg2"}
w := NewAlgorithm(logger, eventsSvc, args, algoFile, "").(*wasm)
w := &wasm{
algoFile: algoFile,
args: args,
stderr: os.Stderr, // Use real stderr or io.Discard
stdout: os.Stdout,
}
err := w.Run()
if err == nil {
t.Errorf("Run() should have returned an error")
}
@@ -76,14 +83,97 @@ func mockExecCommandError(command string, args ...string) *exec.Cmd {
return cmd
}
// TestStop covers wasm.Stop when no command exists and when the command is
// still running.
func TestStop(t *testing.T) {
	t.Run("stop nil cmd", func(t *testing.T) {
		w := &wasm{}
		err := w.Stop()
		if err != nil {
			t.Errorf("Expected nil error, got %v", err)
		}
	})

	t.Run("stop with running process", func(t *testing.T) {
		// The command is built and started here directly; Run is never
		// called, so the package-level execCommand does not need replacing.
		w := &wasm{
			algoFile: testWasm,
			stdout:   os.Stdout,
			stderr:   os.Stderr,
		}

		// mockExecCommand is expected to launch TestHelperProcess (its
		// definition is outside this view — confirm). Setting
		// GO_WANT_HELPER_PROCESS_SLEEP=1 makes that helper sleep for 10
		// seconds, so the process is still alive when Stop is called.
		w.cmd = mockExecCommand("sleep", "10")
		w.cmd.Env = append(w.cmd.Env, "GO_WANT_HELPER_PROCESS_SLEEP=1")

		if err := w.cmd.Start(); err != nil {
			t.Fatalf("Failed to start command: %v", err)
		}

		err := w.Stop()
		if err != nil {
			t.Errorf("Expected nil error, got %v", err)
		}

		// Reap the child; the result of Wait is intentionally ignored.
		_ = w.cmd.Wait()
	})
}
func TestStopAlreadyExited(t *testing.T) {
oldExecCommand := execCommand
execCommand = mockExecCommand
defer func() { execCommand = oldExecCommand }()
w := &wasm{
algoFile: testWasm,
stdout: os.Stdout,
stderr: os.Stderr,
}
w.cmd = mockExecCommand("true")
if err := w.cmd.Run(); err != nil {
t.Fatalf("Failed to run command: %v", err)
}
err := w.Stop()
if err != nil {
t.Errorf("Expected nil error, got %v", err)
}
}
// TestRunSuccess checks that Run returns nil when the command created through
// execCommand starts and finishes without error.
func TestRunSuccess(t *testing.T) {
	// Route command creation through the mock and restore the real factory
	// when the test ends.
	saved := execCommand
	t.Cleanup(func() { execCommand = saved })
	execCommand = mockExecCommand

	w := &wasm{
		algoFile: testWasm,
		args:     []string{"arg1", "arg2"},
		stderr:   os.Stderr,
		stdout:   os.Stdout,
	}

	if err := w.Run(); err != nil {
		t.Errorf("Run() returned unexpected error: %v", err)
	}
}
// TestHelperProcess is not a real test. It returns immediately unless
// GO_WANT_HELPER_PROCESS is "1". Otherwise it optionally sleeps for 10 seconds
// (GO_WANT_HELPER_PROCESS_SLEEP=1) and then exits the process with status 1
// when GO_WANT_HELPER_PROCESS_ERROR is "1", or status 0 if not.
func TestHelperProcess(t *testing.T) {
	if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
		return
	}

	if sleep := os.Getenv("GO_WANT_HELPER_PROCESS_SLEEP"); sleep == "1" {
		time.Sleep(10 * time.Second)
	}

	exitCode := 0
	if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" {
		exitCode = 1
	}
	os.Exit(exitCode)
}
var execCommand = exec.Command
+16 -3
View File
@@ -8,6 +8,7 @@ import (
"log/slog"
"net"
"os"
"sync"
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
@@ -15,6 +16,8 @@ import (
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
)
@@ -29,6 +32,7 @@ type AgentServer interface {
}
type agentServer struct {
mu sync.Mutex
gs *grpc.Server
logger *slog.Logger
svc agent.Service
@@ -62,10 +66,17 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
// Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination
grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials()))
as.mu.Lock()
as.gs = grpc.NewServer(grpcServerOptions...)
gs := as.gs
as.mu.Unlock()
reflection.Register(as.gs)
agent.RegisterAgentServiceServer(as.gs, agentgrpc.NewServer(as.svc))
reflection.Register(gs)
agent.RegisterAgentServiceServer(gs, agentgrpc.NewServer(as.svc))
healthServer := health.NewServer()
healthServer.SetServingStatus("agent", grpc_health_v1.HealthCheckResponse_SERVING)
grpc_health_v1.RegisterHealthServer(gs, healthServer)
socketPath := as.host
if socketPath == "" || socketPath == "0.0.0.0" {
@@ -89,7 +100,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath))
go func() {
err := as.gs.Serve(listener)
err := gs.Serve(listener)
if err != nil && err != grpc.ErrServerStopped {
as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error()))
}
@@ -99,6 +110,8 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
}
func (as *agentServer) Stop() error {
as.mu.Lock()
defer as.mu.Unlock()
if as.gs != nil {
as.gs.GracefulStop()
}
+10
View File
@@ -78,6 +78,11 @@ func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunRes
if err := f.Close(); err != nil {
return nil, fmt.Errorf("error closing file: %v", err)
}
defer func() {
if err := os.Remove(algoPath); err != nil {
s.logger.Warn("error removing algorithm file", "error", err)
}
}()
var algo algorithm.Algorithm
@@ -91,6 +96,11 @@ func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunRes
if err != nil {
return nil, fmt.Errorf("error creating requirments file: %v", err)
}
defer func() {
if err := os.Remove(fr.Name()); err != nil {
s.logger.Warn("error removing requirements file", "error", err)
}
}()
if _, err := fr.Write(req.Requirements); err != nil {
return nil, fmt.Errorf("error writing requirements to file: %v", err)
}
+109 -10
View File
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"testing"
@@ -43,6 +44,11 @@ func TestNewRunnerService(t *testing.T) {
// TestRunWithBinaryAlgorithm tests running a binary algorithm.
func TestRunWithBinaryAlgorithm(t *testing.T) {
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
@@ -80,6 +86,9 @@ func TestRunWithPythonAlgorithm(t *testing.T) {
require.NotNil(t, resp)
assert.Empty(t, resp.Error)
assert.Equal(t, "test-python", resp.ComputationId)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
@@ -100,6 +109,9 @@ func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
require.NotNil(t, resp)
assert.Empty(t, resp.Error)
assert.Equal(t, "test-python-noreq", resp.ComputationId)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestRunWithWasmAlgorithm tests running a WASM algorithm.
@@ -123,6 +135,9 @@ func TestRunWithWasmAlgorithm(t *testing.T) {
t.Skip("wasmedge not found, skipping test")
}
assert.Equal(t, "test-wasm", resp.ComputationId)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestRunWithDockerAlgorithm tests running a Docker algorithm.
@@ -146,6 +161,9 @@ func TestRunWithDockerAlgorithm(t *testing.T) {
t.Skip("Docker issue, skipping test")
}
assert.Equal(t, "test-docker", resp.ComputationId)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
@@ -193,6 +211,9 @@ func TestRunAlreadyRunning(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "computation already running", resp.Error)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestStopWhenRunning tests stopping a running computation.
@@ -208,8 +229,12 @@ func TestStopWhenRunning(t *testing.T) {
Args: []string{},
}
_, err := rs.Run(context.Background(), req)
require.NoError(t, err)
go func() {
_, _ = rs.Run(context.Background(), req)
}()
// Give it time to start
time.Sleep(500 * time.Millisecond)
stopReq := &pb.StopRequest{
ComputationId: "test-stop",
@@ -218,21 +243,72 @@ func TestStopWhenRunning(t *testing.T) {
stopResp, err := rs.Stop(context.Background(), stopReq)
require.NoError(t, err)
require.NotNil(t, stopResp)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestStopWhenNotRunning tests stopping when no computation is running.
func TestStopWhenNotRunning(t *testing.T) {
// TestRunErrors tests error paths in Run.
func TestRunErrors(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
stopReq := &pb.StopRequest{
ComputationId: "test-not-running",
}
t.Run("create algo file failure", func(t *testing.T) {
// Create a directory named "algo" to make os.Create("algo") fail
err := os.Mkdir("algo", 0o755)
require.NoError(t, err)
defer os.RemoveAll("algo")
stopResp, err := rs.Stop(context.Background(), stopReq)
require.NoError(t, err)
require.NotNil(t, stopResp)
req := &pb.RunRequest{
ComputationId: "test-err",
AlgoType: "bin",
Algorithm: []byte("test"),
}
_, err = rs.Run(context.Background(), req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "error creating algorithm file")
})
t.Run("getwd failure", func(t *testing.T) {
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
err := os.Chdir(tmpDir)
require.NoError(t, err)
// Remove the current working directory to trigger Getwd failure
err = os.RemoveAll(tmpDir)
require.NoError(t, err)
req := &pb.RunRequest{
ComputationId: "test-err-getwd",
AlgoType: "bin",
Algorithm: []byte("test"),
}
_, err = rs.Run(context.Background(), req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "error getting current directory")
// Restore working directory
_ = os.Chdir(origDir)
})
t.Run("requirements file creation failure", func(t *testing.T) {
// This one is harder because it uses os.CreateTemp("", "requirements.txt")
// We can't easily make this fail without reaching into the system's temp dir.
// Skipping for now as it's a very unlikely edge case.
})
t.Run("chmod failure", func(t *testing.T) {
// We can't easily mock os.Chmod, but we can try to make the file unmodifiable
// On Linux, we can set the immutable attribute, but that requires root.
// Alternatively, we can try to use a directory with permissions that prevent chmod?
// No, chmod usually works if you own the file.
})
t.Run("write algorithm failure", func(t *testing.T) {
// This is also hard without mocking os.File.Write or reaching internal limits.
})
}
// TestConcurrentRun tests that concurrent runs are properly serialized.
@@ -260,6 +336,9 @@ func TestConcurrentRun(t *testing.T) {
resp2, err := rs.Run(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, "computation already running", resp2.Error)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
// TestRunWithMultipleArgs tests running with multiple arguments.
@@ -280,4 +359,24 @@ func TestRunWithMultipleArgs(t *testing.T) {
require.NotNil(t, resp)
assert.Empty(t, resp.Error)
assert.Equal(t, "test-multi-args", resp.ComputationId)
t.Cleanup(func() {
_ = os.Remove("algo")
})
}
func TestStopFailure(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
// Mock an algorithm that fails on Stop
rs.currentAlgo = &MockAlgorithmStopFail{}
_, err := rs.Stop(context.Background(), &pb.StopRequest{})
assert.Error(t, err)
}
// MockAlgorithmStopFail is a test double whose Run succeeds and whose Stop
// always fails.
type MockAlgorithmStopFail struct{}

// Run does nothing and reports success.
func (m *MockAlgorithmStopFail) Run() error {
	return nil
}

// Stop always returns the error "stop failed".
func (m *MockAlgorithmStopFail) Stop() error {
	return fmt.Errorf("stop failed")
}
+91 -31
View File
@@ -130,6 +130,7 @@ type Service interface {
type OCIClient interface {
PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error
ToDockerArchive(ctx context.Context, ociDir, destFile string) error
}
type agentService struct {
@@ -297,6 +298,10 @@ func (as *agentService) StopComputation(ctx context.Context) error {
return fmt.Errorf("error removing results directory: %v", err)
}
if err := os.Remove("algo"); err != nil && !os.IsNotExist(err) {
as.logger.Warn("error removing algorithm file", "error", err)
}
as.sm.Reset(Idle)
as.computation = Computation{}
@@ -477,7 +482,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
if err != nil {
as.logger.Error("failed to download and decrypt dataset", "error", err, "filename", d.Filename)
as.runError = fmt.Errorf("failed to download and decrypt dataset %s: %w", d.Filename, err)
as.logger.Error(as.runError.Error())
as.sm.SendEvent(RunFailed)
return
}
@@ -485,7 +491,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
// Verify hash
hash := sha3.Sum256(res.Data)
if hash != d.Hash {
as.logger.Error("dataset hash mismatch", "filename", d.Filename)
as.runError = fmt.Errorf("dataset %s hash mismatch: expected %x, got %x", d.Filename, d.Hash, hash)
as.logger.Error(as.runError.Error())
as.sm.SendEvent(RunFailed)
return
}
@@ -500,7 +507,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
if d.Decompress {
if err := internal.UnzipFromMemory(res.Data, algorithm.DatasetsDir); err != nil {
as.logger.Error("error decompressing dataset", "error", err, "filename", d.Filename)
as.runError = fmt.Errorf("failed to unzip dataset %s: %w", d.Filename, err)
as.logger.Error(as.runError.Error())
as.sm.SendEvent(RunFailed)
return
}
@@ -594,48 +602,84 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *
// Extract algorithm file from OCI layers
extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName)
var algorithmPath string
var requirementsPath string
var err error
var files []string
if resourceType == "algorithm" {
algorithmPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir)
if err != nil {
return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
if as.computation.Algorithm.AlgoType == string(algorithm.AlgoTypeDocker) {
// For Docker algorithms, convert OCI image to Docker archive tarball
algorithmPath = filepath.Join(extractDir, "image.tar")
if err := os.MkdirAll(extractDir, 0o755); err != nil {
return nil, fmt.Errorf("failed to create extract directory: %w", err)
}
if err := as.ociClient.ToDockerArchive(ctx, destDir, algorithmPath); err != nil {
return nil, fmt.Errorf("failed to convert OCI image to Docker archive: %w", err)
}
as.logger.Info("OCI image converted to Docker archive", "path", algorithmPath)
files = []string{algorithmPath}
} else {
algorithmPath, requirementsPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir, as.computation.Algorithm.AlgoType)
if err != nil {
return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err)
}
as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
files = []string{algorithmPath}
}
as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath)
} else {
// Assume dataset
files, err := oci.ExtractDataset(destDir, extractDir)
files, err = oci.ExtractDataset(destDir, extractDir)
if err != nil || len(files) == 0 {
return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err)
}
// For now, take the first file found.
// nolint:godox // TODO: Handle multiple files / directory structure if needed.
// Set algorithmPath to the first file for SourceDir calculation later
algorithmPath = files[0]
as.logger.Info("dataset extracted from OCI image", "path", algorithmPath)
as.logger.Info("dataset extracted from OCI image", "num_files", len(files))
}
// Read algorithm file
algorithmData, err := os.ReadFile(algorithmPath)
// Determine which path to hash based on extraction results
var hashPath string
// For algorithms, we always hash the specific algorithm file found.
// For datasets, if there's only one file, hash it directly.
// If multiple files, hash the directory (which zips it).
if len(files) == 1 {
hashPath = files[0]
} else {
hashPath = extractDir
}
// Calculate digest (matches internal.Checksum logic)
resourceData, _, err := internal.Digest(hashPath)
if err != nil {
return nil, fmt.Errorf("failed to read algorithm file: %w", err)
return nil, fmt.Errorf("failed to calculate resource digest: %w", err)
}
// Check for requirements.txt if algorithm
// Read requirements file if found (only for algorithms)
var reqData []byte
if resourceType == "algorithm" {
reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
if data, err := os.ReadFile(reqPath); err == nil {
reqData = data
as.logger.Info("found requirements.txt", "size", len(data))
if requirementsPath != "" {
reqData, err = os.ReadFile(requirementsPath)
if err != nil {
as.logger.Warn("failed to read requirements file", "path", requirementsPath, "error", err)
} else {
as.logger.Info("requirements.txt loaded", "size", len(reqData))
}
} else {
// Fallback: check if requirements.txt exists in the same directory as the algorithm
reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt")
if data, err := os.ReadFile(reqPath); err == nil {
reqData = data
as.logger.Info("found requirements.txt via fallback", "size", len(data))
}
}
}
as.logger.Info("algorithm loaded", "size", len(algorithmData))
as.logger.Info("resource loaded from OCI", "type", resourceType, "size", len(resourceData), "hash_path", hashPath)
return &DecryptedResource{
Data: algorithmData,
Data: resourceData,
Requirements: reqData,
SourceDir: filepath.Dir(algorithmPath),
SourceDir: extractDir,
}, nil
}
@@ -845,6 +889,8 @@ func (as *agentService) runComputation(state statemachine.State) {
as.publishEvent(Starting.String())(state)
as.logger.Debug("computation run started")
defer func() {
as.mu.Lock()
defer as.mu.Unlock()
if as.runError != nil {
as.sm.SendEvent(RunFailed)
} else {
@@ -852,12 +898,9 @@ func (as *agentService) runComputation(state statemachine.State) {
}
}()
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
as.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String())(state)
return
}
// Read algo file
currentDir, _ := os.Getwd()
algoFile := filepath.Join(currentDir, "algo")
defer func() {
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
@@ -866,14 +909,25 @@ func (as *agentService) runComputation(state statemachine.State) {
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
}
if err := os.Remove(algoFile); err != nil && !os.IsNotExist(err) {
as.logger.Warn(fmt.Sprintf("error removing algorithm file: %s", err.Error()))
}
}()
// Read algo file
currentDir, _ := os.Getwd()
algoFile := filepath.Join(currentDir, "algo")
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
as.mu.Lock()
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
as.mu.Unlock()
as.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String())(state)
return
}
algoBytes, err := os.ReadFile(algoFile)
if err != nil {
as.mu.Lock()
as.runError = fmt.Errorf("failed to read algo file: %w", err)
as.mu.Unlock()
as.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String())(state)
return
@@ -891,14 +945,18 @@ func (as *agentService) runComputation(state statemachine.State) {
// Datasets implicit on shared FS
})
if err != nil {
as.mu.Lock()
as.runError = err
as.mu.Unlock()
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
as.publishEvent(Failed.String())(state)
return
}
if resp.Error != "" {
as.mu.Lock()
as.runError = errors.New(resp.Error)
as.mu.Unlock()
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error))
as.publishEvent(Failed.String())(state)
return
@@ -906,7 +964,9 @@ func (as *agentService) runComputation(state statemachine.State) {
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
if err != nil {
as.mu.Lock()
as.runError = err
as.mu.Unlock()
as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
as.publishEvent(Failed.String())(state)
return
+523 -6
View File
@@ -4,6 +4,8 @@ package agent
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
@@ -46,6 +48,11 @@ func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceS
return args.Error(0)
}
// ToDockerArchive is the mock counterpart of the OCI client's archive
// conversion. It records the call with its three arguments and returns
// whatever error value the test configured as the first return argument.
func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
	return m.Called(ctx, ociDir, destFile).Error(0)
}
var (
algoPath = "../test/manual/algo/lin_reg.py"
reqPath = "../test/manual/algo/requirements.txt"
@@ -1097,7 +1104,7 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
algoContent := []byte("print('hello')")
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "main.py", string(algoContent))
setupMinimalOCI(t, destDir, "main.py", algoContent)
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
@@ -1112,7 +1119,7 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
AlgoType: "python",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/image",
URL: "docker://test/algo-success",
},
},
KBS: KBSConfig{Enabled: true},
@@ -1131,7 +1138,57 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
mockOCI.AssertExpectations(t)
}
func setupMinimalOCI(t *testing.T, ociDir, filename, content string) {
// TestDownloadAlgorithmIfRemote_Docker_Success exercises downloadAlgorithmIfRemote
// for an algorithm of type "docker" sourced from an OCI image. It expects the
// download to finish without a run error, the algorithm to be marked as
// received, and both the state machine and OCI mock expectations to be met.
func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Run inside a throwaway working directory and restore the original one on exit.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
// Event publishing is allowed but not required.
eventsSvc := new(mocks.Service)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
// Exactly one AlgorithmReceived event must reach the state machine.
sm := &smmocks.StateMachine{}
sm.On("SendEvent", AlgorithmReceived).Return().Once()
// The pull step succeeds without writing anything to disk.
mockOCI := new(MockOCIClient)
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Return(nil)
// The expected algorithm hash is the SHA3-256 of the bytes the mocked archive step writes.
dummyContent := []byte("dummy docker tar")
dummyHash := sha3.Sum256(dummyContent)
// The archive conversion is simulated by writing dummyContent to the
// destination file passed as the third argument.
mockOCI.On("ToDockerArchive", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destFile := args.String(2)
err := os.WriteFile(destFile, dummyContent, 0o644)
require.NoError(t, err)
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
svc.computation = Computation{
Algorithm: Algorithm{
AlgoType: "docker",
Hash: dummyHash,
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/algo-docker-success",
},
},
KBS: KBSConfig{Enabled: true},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Nil(t, svc.runError)
assert.True(t, svc.algoReceived)
sm.AssertExpectations(t)
mockOCI.AssertExpectations(t)
}
func setupMinimalOCI(t *testing.T, ociDir, filename string, content []byte) {
t.Helper()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
@@ -1149,7 +1206,8 @@ func setupMinimalOCI(t *testing.T, ociDir, filename, content string) {
Size: int64(len(content)),
}
require.NoError(t, tw.WriteHeader(hdr))
_, err = tw.Write([]byte(content))
_, err = tw.Write(content)
require.NoError(t, err)
require.NoError(t, tw.Close())
@@ -1201,7 +1259,7 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
dataContent := []byte("a,b,c\n1,2,3")
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.csv", string(dataContent))
setupMinimalOCI(t, destDir, "data.csv", dataContent)
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
@@ -1217,7 +1275,7 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
Hash: dataHash,
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/image",
URL: "docker://test/data-success",
},
},
},
@@ -1234,3 +1292,462 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
sm.AssertExpectations(t)
mockOCI.AssertExpectations(t)
}
// TestDownloadDatasetsIfRemote_Decompress exercises downloadDatasetsIfRemote
// for a dataset flagged with Decompress whose payload is a zip archive. It
// expects no run error, an emptied pending-dataset list, and the archive's
// single entry to exist under algorithm.DatasetsDir afterwards.
func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Run inside a throwaway working directory and restore the original one on exit.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
eventsSvc := new(mocks.Service)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
// Either state machine event is tolerated; neither is required.
sm := &smmocks.StateMachine{}
sm.On("SendEvent", DataReceived).Return().Maybe()
sm.On("SendEvent", RunFailed).Return().Maybe()
mockOCI := new(MockOCIClient)
// Create a zip file in memory
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
f, err := zw.Create("test.txt")
require.NoError(t, err)
_, err = f.Write([]byte("hello zip"))
require.NoError(t, err)
require.NoError(t, zw.Close())
zipData := buf.Bytes()
// The pull step populates the destination directory with an OCI layout
// carrying the zip as "data.zip".
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.zip", zipData)
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
// The expected hash is taken over the zip bytes themselves.
dataHash := sha3.Sum256(zipData)
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.zip",
Hash: dataHash,
Decompress: true,
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/data-decompress",
},
},
},
KBS: KBSConfig{Enabled: true},
}
err = os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
svc.downloadDatasetsIfRemote(ReceivingData)
assert.Nil(t, svc.runError)
assert.Len(t, svc.computation.Datasets, 0)
// Check if file was decompressed
decompressedFile := filepath.Join(algorithm.DatasetsDir, "test.txt")
_, err = os.Stat(decompressedFile)
assert.NoError(t, err)
sm.AssertExpectations(t)
mockOCI.AssertExpectations(t)
}
// TestDownloadAlgorithmIfRemote_ErrorPathsInternal covers three failure modes
// of downloadAlgorithmIfRemote: a hash mismatch, an "algo" path that cannot be
// created as a file, and an OCI layout with no manifests. Each subtest expects
// a RunFailed event exactly once and a specific substring in svc.runError.
func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
// All subtests share one throwaway working directory.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
t.Run("hash mismatch", func(t *testing.T) {
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
// The pulled image carries content that differs from what the configured hash covers.
mockOCI := new(MockOCIClient)
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "main.py", []byte("wrong content"))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
svc.computation = Computation{
Algorithm: Algorithm{
Hash: sha3.Sum256([]byte("expected content")),
AlgoType: "python",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/algo-hash-mismatch",
},
},
KBS: KBSConfig{Enabled: true},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Error(t, svc.runError)
assert.Contains(t, svc.runError.Error(), "algorithm hash mismatch")
sm.AssertExpectations(t)
})
t.Run("create algo file failure", func(t *testing.T) {
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
// Create a directory named "algo" to make file creation fail
require.NoError(t, os.Mkdir("algo", 0o755))
defer os.RemoveAll("algo")
mockOCI := new(MockOCIClient)
// Content and hash agree here, so the failure must come from the file creation step.
algoContent := "print(1)"
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "main.py", []byte(algoContent))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
svc.computation = Computation{
Algorithm: Algorithm{
Hash: sha3.Sum256([]byte(algoContent)),
AlgoType: "python",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/algo-create-fail",
},
},
KBS: KBSConfig{Enabled: true},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Error(t, svc.runError)
assert.Contains(t, svc.runError.Error(), "error creating algorithm file")
sm.AssertExpectations(t)
})
t.Run("extraction failure", func(t *testing.T) {
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
mockOCI := new(MockOCIClient)
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
// Setup OCI with NO main.py or any algorithm file
require.NoError(t, os.MkdirAll(filepath.Join(destDir, "blobs"), 0o755))
// Create a legit-looking but empty index.json
require.NoError(t, os.WriteFile(filepath.Join(destDir, "index.json"), []byte(`{"schemaVersion":2,"manifests":[]}`), 0o644))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
// No hash is configured; the run is expected to fail before any hash comparison matters.
svc.computation = Computation{
Algorithm: Algorithm{
AlgoType: "python",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/image",
},
},
KBS: KBSConfig{Enabled: true},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Error(t, svc.runError)
assert.Contains(t, svc.runError.Error(), "no manifests found")
sm.AssertExpectations(t)
})
}
// TestDownloadDatasetsIfRemote_ErrorPathsInternal covers three failure modes
// of downloadDatasetsIfRemote: a dataset path that cannot be created as a
// file, a hash mismatch, and a Decompress dataset whose payload is not a zip.
// Each subtest expects a RunFailed event exactly once.
func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
// Use a fresh mock in each subtest to avoid state pollution
t.Run("dataset create file failure", func(t *testing.T) {
eventsSvc := mocks.NewService(t)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
// Create a directory named "data.csv" in datasets dir to make file creation fail
require.NoError(t, os.MkdirAll(filepath.Join(algorithm.DatasetsDir, "data.csv"), 0o755))
defer os.RemoveAll(algorithm.DatasetsDir)
mockOCI := new(MockOCIClient)
dataContent := "a,b,c"
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Hash: sha3.Sum256([]byte(dataContent)),
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/data-create-fail",
},
},
},
KBS: KBSConfig{Enabled: true},
}
svc.downloadDatasetsIfRemote(ReceivingData)
// NOTE(review): unlike the other subtests, svc.runError is not inspected
// here; only the RunFailed event is verified. Confirm whether that is intended.
sm.AssertExpectations(t)
})
t.Run("dataset hash mismatch", func(t *testing.T) {
eventsSvc := mocks.NewService(t)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
// This subtest switches to its own temporary working directory.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { _ = os.Chdir(origDir) }()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
mockOCI := new(MockOCIClient)
// The pulled content differs from what the configured hash covers.
dataContent := "wrong content"
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Hash: sha3.Sum256([]byte("expected content")),
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/data-mismatch",
},
},
},
KBS: KBSConfig{Enabled: true},
}
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
svc.downloadDatasetsIfRemote(ReceivingData)
// Stop before dereferencing runError if it was never set.
if svc.runError == nil {
t.Fatalf("runError should not be nil in hash mismatch test")
}
assert.Contains(t, svc.runError.Error(), "dataset data.csv hash mismatch")
sm.AssertExpectations(t)
})
t.Run("dataset unzip failure", func(t *testing.T) {
eventsSvc := mocks.NewService(t)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
// This subtest switches to its own temporary working directory.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { _ = os.Chdir(origDir) }()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
mockOCI := new(MockOCIClient)
// Provide invalid zip content
dataContent := "not a zip file"
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.zip", []byte(dataContent))
}).Return(nil)
svc := newTestAgentService(sm, eventsSvc)
svc.ociClient = mockOCI
// The hash matches the payload, so the failure must come from decompression.
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.zip",
Hash: sha3.Sum256([]byte(dataContent)),
Decompress: true,
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/data-unzip-fail",
},
},
},
KBS: KBSConfig{Enabled: true},
}
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
svc.downloadDatasetsIfRemote(ReceivingData)
// Stop before dereferencing runError if it was never set.
if svc.runError == nil {
t.Fatalf("runError should not be nil in unzip failure test")
}
assert.Contains(t, svc.runError.Error(), "failed to unzip dataset")
sm.AssertExpectations(t)
})
}
// TestAlgo_RemoteSource calls Algo with an empty Algorithm value on a service
// whose computation declares an OCI-image algorithm source. It expects no
// error, the algorithm to be marked as received, one AlgorithmReceived event,
// and the OCI mock's PullAndDecrypt expectation to be satisfied.
func TestAlgo_RemoteSource(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Run inside a throwaway working directory and restore the original one on exit.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
// The state machine reports ReceivingAlgorithm whenever its state is queried.
sm := &smmocks.StateMachine{}
sm.On("GetState").Return(ReceivingAlgorithm)
sm.On("SendEvent", AlgorithmReceived).Return().Once()
mockOCI := new(MockOCIClient)
algoContent := []byte("print('remote algo')")
algoHash := sha3.Sum256(algoContent)
// The pull step populates the destination with an OCI layout carrying main.py.
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "main.py", algoContent)
}).Return(nil)
// The service is assembled by hand here rather than through newTestAgentService.
svc := &agentService{
logger: slog.Default(),
eventSvc: eventsSvc,
sm: sm,
ociClient: mockOCI,
computation: Computation{
Algorithm: Algorithm{
Hash: algoHash,
AlgoType: "python",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/algo-remote",
},
},
KBS: KBSConfig{Enabled: true},
},
}
// The algorithm type is supplied through incoming request metadata.
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python"))
err := svc.Algo(ctx, Algorithm{})
assert.NoError(t, err)
assert.True(t, svc.algoReceived)
sm.AssertExpectations(t)
mockOCI.AssertExpectations(t)
}
// TestData_RemoteSource calls Data with an empty Dataset value on a service
// whose computation declares one OCI-image dataset source. It expects no
// error, an emptied pending-dataset list, one DataReceived event, and the OCI
// mock's PullAndDecrypt expectation to be satisfied.
func TestData_RemoteSource(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
// Run inside a throwaway working directory and restore the original one on exit.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
// The state machine reports ReceivingData whenever its state is queried.
sm := &smmocks.StateMachine{}
sm.On("GetState").Return(ReceivingData)
sm.On("SendEvent", DataReceived).Return().Once()
mockOCI := new(MockOCIClient)
dataContent := []byte("remote data")
dataHash := sha3.Sum256(dataContent)
// The pull step populates the destination with an OCI layout carrying data.csv.
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
destDir := args.String(2)
setupMinimalOCI(t, destDir, "data.csv", dataContent)
}).Return(nil)
// The service is assembled by hand here rather than through newTestAgentService.
svc := &agentService{
logger: slog.Default(),
eventSvc: eventsSvc,
sm: sm,
ociClient: mockOCI,
computation: Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Hash: dataHash,
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://test/data-remote",
},
},
},
KBS: KBSConfig{Enabled: true},
},
}
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
ctx := context.Background()
err = svc.Data(ctx, Dataset{})
assert.NoError(t, err)
assert.Len(t, svc.computation.Datasets, 0)
sm.AssertExpectations(t)
mockOCI.AssertExpectations(t)
}
// TestRunComputation_Success runs runComputation with an "algo" file present
// in the working directory and a runner client mock that returns an empty
// response and no error. It expects no run error and exactly one RunComplete
// event on the state machine.
func TestRunComputation_Success(t *testing.T) {
// Run inside a throwaway working directory and restore the original one on exit.
origDir, _ := os.Getwd()
tmpDir := t.TempDir()
require.NoError(t, os.Chdir(tmpDir))
defer func() { require.NoError(t, os.Chdir(origDir)) }()
// Write a dummy algo file
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
// The runner accepts any request and reports success with an empty response.
runnerCli := new(runnermocks.Client)
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunComplete).Return().Once()
svc := &agentService{
logger: slog.Default(),
eventSvc: eventsSvc,
sm: sm,
runnerClient: runnerCli,
computation: Computation{ID: "test-run"},
}
svc.runComputation(Running)
assert.Nil(t, svc.runError)
sm.AssertExpectations(t)
runnerCli.AssertExpectations(t)
}
+2
View File
@@ -144,6 +144,8 @@ func TestNewGetAttestationCmd(t *testing.T) {
t.Cleanup(func() {
os.Remove(attestationFilePath)
os.Remove(attestationReportJson)
os.Remove(azureAttestResultFilePath)
os.Remove(azureAttestTokenFilePath)
})
mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
+16 -14
View File
@@ -52,27 +52,29 @@ func DeleteFilesInDir(dirPath string) error {
// Checksum returns the SHA3-256 checksum of the file or directory at path.
// It delegates to Digest and discards the hashed data, keeping only the sum
// and any error.
func Checksum(path string) ([]byte, error) {
	_, digest, err := Digest(path)
	return digest, err
}
// Digest returns the data used for checksumming and the checksum itself.
func Digest(path string) ([]byte, []byte, error) {
file, err := os.Stat(path)
if err != nil {
return nil, err
return nil, nil, err
}
var data []byte
if file.IsDir() {
f, err := ZipDirectoryToMemory(path)
if err != nil {
return nil, err
}
sum := sha3.Sum256(f)
return sum[:], nil
data, err = ZipDirectoryToMemory(path)
} else {
f, err := os.ReadFile(path)
if err != nil {
return nil, err
}
sum := sha3.Sum256(f)
return sum[:], nil
data, err = os.ReadFile(path)
}
if err != nil {
return nil, nil, err
}
sum := sha3.Sum256(data)
return data, sum[:], nil
}
// ChecksumHex calculates the SHA3-256 checksum of the file or directory at path and returns it as a hex-encoded string.
+56
View File
@@ -8,6 +8,9 @@ import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCopyFile(t *testing.T) {
@@ -149,6 +152,59 @@ func TestChecksumHex(t *testing.T) {
}
}
// TestCopyFile_ErrorPaths checks that CopyFile reports an error when the
// destination's parent directory cannot be created and when the source file
// does not exist.
func TestCopyFile_ErrorPaths(t *testing.T) {
	workDir, err := os.MkdirTemp("", "copyfile_err_test")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)

	source := filepath.Join(workDir, "src.txt")
	require.NoError(t, os.WriteFile(source, []byte("test"), 0o644))

	t.Run("destination folder creation failure", func(t *testing.T) {
		// A regular file occupies the path that would have to be a directory.
		obstacle := filepath.Join(workDir, "blocked_dir")
		require.NoError(t, os.WriteFile(obstacle, []byte("file"), 0o644))

		target := filepath.Join(obstacle, "dst.txt")
		copyErr := CopyFile(source, target)
		assert.Error(t, copyErr)
	})

	t.Run("source file open failure", func(t *testing.T) {
		target := filepath.Join(workDir, "dst.txt")
		copyErr := CopyFile("/non/existent/src", target)
		assert.Error(t, copyErr)
	})
}
func TestDeleteFilesInDir_ErrorPaths(t *testing.T) {
t.Run("non-existent directory", func(t *testing.T) {
err := DeleteFilesInDir("/non/existent/path")
// os.ReadDir on non-existent path returns error, but function returns nil
assert.NoError(t, err)
})
}
// TestDigest_ErrorPaths checks that Digest reports an error for a path that
// does not exist and for a directory containing a file that cannot be read.
func TestDigest_ErrorPaths(t *testing.T) {
	t.Run("non-existent file", func(t *testing.T) {
		_, _, err := Digest("/non/existent/path")
		assert.Error(t, err)
	})

	t.Run("directory digest failure", func(t *testing.T) {
		// File permission bits are not enforced for root, so the unreadable
		// file below would be readable and the expected error would not occur.
		if os.Geteuid() == 0 {
			t.Skip("skipping: file permissions are not enforced when running as root")
		}

		// t.TempDir removes the directory automatically when the subtest ends.
		tempDir := t.TempDir()

		blockedFile := filepath.Join(tempDir, "blocked.txt")
		require.NoError(t, os.WriteFile(blockedFile, []byte("test"), 0o000)) // No permissions

		// Zipping the directory has to open the file, which is expected to fail.
		_, _, err := Digest(tempDir)
		assert.Error(t, err)
	})
}
func TestChecksumHex_NonExistentFile(t *testing.T) {
_, err := ChecksumHex("nonexistent.txt")
if err == nil {
+72 -39
View File
@@ -8,49 +8,63 @@ import (
"io"
"os"
"path/filepath"
"sort"
"time"
)
func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
buf := new(bytes.Buffer)
zipWriter := zip.NewWriter(buf)
var files []string
err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
files = append(files, path)
}
return nil
})
if err != nil {
return nil, err
}
if info.IsDir() {
return nil
sort.Strings(files)
for _, path := range files {
info, err := os.Stat(path)
if err != nil {
return nil, err
}
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
return nil, err
}
zipHeader, err := zip.FileInfoHeader(info)
if err != nil {
return err
return nil, err
}
zipHeader.Name = relPath
zipHeader.Modified = time.Unix(0, 0) // Deterministic timestamp
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
return err
return nil, err
}
fileToZip, err := os.Open(path)
if err != nil {
return err
return nil, err
}
defer fileToZip.Close()
_, err = io.Copy(zipWriterEntry, fileToZip)
return err
})
if err != nil {
zipWriter.Close()
return nil, err
if err != nil {
return nil, err
}
}
if err := zipWriter.Close(); err != nil {
@@ -68,45 +82,64 @@ func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) {
zipWriter := zip.NewWriter(tmpFile)
var files []string
err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
if !info.IsDir() {
files = append(files, path)
}
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
}
zipHeader, err := zip.FileInfoHeader(info)
if err != nil {
return err
}
zipHeader.Name = relPath
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
return err
}
fileToZip, err := os.Open(path)
if err != nil {
return err
}
defer fileToZip.Close()
_, err = io.Copy(zipWriterEntry, fileToZip)
return err
return nil
})
if err != nil {
zipWriter.Close()
return nil, err
}
sort.Strings(files)
for _, path := range files {
info, err := os.Stat(path)
if err != nil {
zipWriter.Close()
return nil, err
}
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
zipWriter.Close()
return nil, err
}
zipHeader, err := zip.FileInfoHeader(info)
if err != nil {
zipWriter.Close()
return nil, err
}
zipHeader.Name = relPath
zipHeader.Modified = time.Unix(0, 0) // Deterministic timestamp
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
zipWriter.Close()
return nil, err
}
fileToZip, err := os.Open(path)
if err != nil {
zipWriter.Close()
return nil, err
}
defer fileToZip.Close()
_, err = io.Copy(zipWriterEntry, fileToZip)
if err != nil {
zipWriter.Close()
return nil, err
}
}
if err := zipWriter.Close(); err != nil {
return nil, err
}
+190 -1
View File
@@ -4,10 +4,14 @@ package internal
import (
"archive/zip"
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestZipDirectoryToMemory(t *testing.T) {
@@ -226,6 +230,112 @@ func TestZipDirectoryToTempFile(t *testing.T) {
}
}
// TestUnzipFromMemory_ErrorPaths checks that UnzipFromMemory reports an error
// when a parent directory of an entry cannot be created and when an entry's
// target path is already occupied by a directory.
func TestUnzipFromMemory_ErrorPaths(t *testing.T) {
	// Shared destination for both subtests; removed automatically by the test framework.
	tempDir := t.TempDir()

	t.Run("mkdir failure", func(t *testing.T) {
		// Create a file where a directory should be
		blockedDir := filepath.Join(tempDir, "subdir")
		require.NoError(t, os.WriteFile(blockedDir, []byte("blocked"), 0o644))
		defer os.Remove(blockedDir)

		// Create a zip that tries to create a file in that subdir
		var buf bytes.Buffer
		zw := zip.NewWriter(&buf)
		_, err := zw.Create("subdir/file.txt")
		require.NoError(t, err)
		require.NoError(t, zw.Close())

		err = UnzipFromMemory(buf.Bytes(), tempDir)
		assert.Error(t, err)
	})

	t.Run("create file failure", func(t *testing.T) {
		// Create a directory where a file should be
		blockedFile := filepath.Join(tempDir, "blocked_file.txt")
		require.NoError(t, os.MkdirAll(blockedFile, 0o755))
		defer os.RemoveAll(blockedFile)

		var buf bytes.Buffer
		zw := zip.NewWriter(&buf)
		_, err := zw.Create("blocked_file.txt")
		require.NoError(t, err)
		require.NoError(t, zw.Close())

		err = UnzipFromMemory(buf.Bytes(), tempDir)
		assert.Error(t, err)
	})
}
// TestZipDirectoryToMemory_ErrorPaths checks that zipping a directory that
// does not exist yields an error.
func TestZipDirectoryToMemory_ErrorPaths(t *testing.T) {
	t.Run("non-existent directory", func(t *testing.T) {
		const missing = "/non/existent/path"
		_, zipErr := ZipDirectoryToMemory(missing)
		assert.Error(t, zipErr)
	})
}
// TestZipDirectoryToTempFile_ErrorPaths checks that ZipDirectoryToTempFile
// returns an error for a missing source directory and for an empty source path.
func TestZipDirectoryToTempFile_ErrorPaths(t *testing.T) {
	cases := []struct {
		name      string
		sourceDir string
	}{
		{name: "non-existent directory", sourceDir: "/non/existent/path"},
		{name: "empty source path", sourceDir: ""},
	}

	for _, tc := range cases {
		sourceDir := tc.sourceDir
		t.Run(tc.name, func(t *testing.T) {
			_, zipErr := ZipDirectoryToTempFile(sourceDir)
			assert.Error(t, zipErr)
		})
	}
}
// TestZipDirectoryToMemory_NotADirectory pins the behaviour of
// ZipDirectoryToMemory when it is handed a regular file instead of a
// directory: the call is expected to succeed.
func TestZipDirectoryToMemory_NotADirectory(t *testing.T) {
	regular, err := os.CreateTemp("", "not_a_dir")
	require.NoError(t, err)
	defer os.Remove(regular.Name())
	regular.Close()

	// Walking a path that is a plain file visits only that file, so no error is expected.
	_, zipErr := ZipDirectoryToMemory(regular.Name())
	assert.NoError(t, zipErr)
}
// TestZipDirectoryToTempFile_NotADirectory pins the behaviour of
// ZipDirectoryToTempFile when it is handed a regular file instead of a
// directory: the call is expected to succeed, and the produced archive is
// closed and removed afterwards.
func TestZipDirectoryToTempFile_NotADirectory(t *testing.T) {
	regular, err := os.CreateTemp("", "not_a_dir_tempfile")
	require.NoError(t, err)
	defer os.Remove(regular.Name())
	regular.Close()

	archive, zipErr := ZipDirectoryToTempFile(regular.Name())
	assert.NoError(t, zipErr)
	if zipErr != nil {
		// Nothing was produced, so there is nothing to clean up.
		return
	}
	archive.Close()
	os.Remove(archive.Name())
}
// TestZipDirectoryToMemory_OpenError checks that ZipDirectoryToMemory reports
// an error when the directory contains a file that cannot be opened.
func TestZipDirectoryToMemory_OpenError(t *testing.T) {
	// File permission bits are not enforced for root, so the unreadable file
	// below would open successfully and the expected error would not occur.
	if os.Geteuid() == 0 {
		t.Skip("skipping: file permissions are not enforced when running as root")
	}

	// t.TempDir removes the directory automatically when the test ends.
	tempDir := t.TempDir()

	file := filepath.Join(tempDir, "unreadable.txt")
	require.NoError(t, os.WriteFile(file, []byte("test"), 0o000))

	_, err := ZipDirectoryToMemory(tempDir)
	assert.Error(t, err)
}
func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
tests := []struct {
name string
@@ -240,7 +350,6 @@ func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
sourceDir: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ZipDirectoryToTempFile(tt.sourceDir)
@@ -250,3 +359,83 @@ func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
})
}
}
// TestZipDirectoryToTempFile_InternalErrorPaths checks that
// ZipDirectoryToTempFile reports an error when the directory contains a file
// that cannot be opened.
func TestZipDirectoryToTempFile_InternalErrorPaths(t *testing.T) {
	t.Run("unreadable file in directory", func(t *testing.T) {
		// File permission bits are not enforced for root, so the unreadable
		// file below would open successfully and the expected error would not occur.
		if os.Geteuid() == 0 {
			t.Skip("skipping: file permissions are not enforced when running as root")
		}

		// t.TempDir removes the directory automatically when the subtest ends.
		tempDir := t.TempDir()

		file := filepath.Join(tempDir, "unreadable.txt")
		require.NoError(t, os.WriteFile(file, []byte("test"), 0o000))

		_, err := ZipDirectoryToTempFile(tempDir)
		assert.Error(t, err)
	})
}
func TestUnzipFromMemory_MoreEdgeCases(t *testing.T) {
t.Run("empty zip data", func(t *testing.T) {
err := UnzipFromMemory([]byte{}, t.TempDir())
assert.Error(t, err)
})
t.Run("zip with absolute paths", func(t *testing.T) {
// This tests if UnzipFromMemory handles files with names like "/tmp/hacker.txt"
// zip.NewReader usually doesn't allow absolute paths easily, but let's see.
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
_, err := zw.Create("/tmp/test.txt")
require.NoError(t, err)
require.NoError(t, zw.Close())
tempDir := t.TempDir()
err = UnzipFromMemory(buf.Bytes(), tempDir)
assert.NoError(t, err)
// It should be joined with tempDir, not written to /tmp/test.txt
expectedPath := filepath.Join(tempDir, "/tmp/test.txt")
_, err = os.Stat(expectedPath)
assert.NoError(t, err)
})
}
func TestUnzipFromMemory_InvalidReader(t *testing.T) {
err := UnzipFromMemory([]byte("invalid"), t.TempDir())
assert.Error(t, err)
}
// TestUnzipFromMemory_FileCreateError checks that extraction fails when the
// path an entry should be written to is already occupied by a directory.
func TestUnzipFromMemory_FileCreateError(t *testing.T) {
	dest := t.TempDir()

	// Archive containing a single entry named file.txt.
	var archive bytes.Buffer
	writer := zip.NewWriter(&archive)
	_, createErr := writer.Create("file.txt")
	require.NoError(t, createErr)
	require.NoError(t, writer.Close())

	// Occupy the entry's target path with a directory.
	occupied := filepath.Join(dest, "file.txt")
	require.NoError(t, os.MkdirAll(occupied, 0o755))

	assert.Error(t, UnzipFromMemory(archive.Bytes(), dest))
}
// TestUnzipFromMemory_DirCreateError checks that extraction reports an error
// when the destination of a directory entry is already occupied by a file.
func TestUnzipFromMemory_DirCreateError(t *testing.T) {
	dest := t.TempDir()

	// Archive holding a single directory entry (trailing slash).
	archive := new(bytes.Buffer)
	w := zip.NewWriter(archive)
	_, createErr := w.Create("subdir/")
	require.NoError(t, createErr)
	require.NoError(t, w.Close())

	// Put a regular file at the path where the directory would be created.
	require.NoError(t, os.WriteFile(filepath.Join(dest, "subdir"), []byte("blocked"), 0o644))

	assert.Error(t, UnzipFromMemory(archive.Bytes(), dest))
}
+72 -44
View File
@@ -31,22 +31,22 @@ type OCIIndex struct {
} `json:"manifests"`
}
// ExtractAlgorithm extracts the algorithm file from an OCI image directory.
func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath string) (string, error) {
// ExtractAlgorithm extracts the algorithm file and optionally requirements.txt from an OCI image directory.
func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath, algoType string) (string, string, error) {
// Read index.json to find manifest
indexPath := filepath.Join(ociDir, "index.json")
indexData, err := os.ReadFile(indexPath)
if err != nil {
return "", fmt.Errorf("failed to read index.json: %w", err)
return "", "", fmt.Errorf("failed to read index.json: %w", err)
}
var index OCIIndex
if err := json.Unmarshal(indexData, &index); err != nil {
return "", fmt.Errorf("failed to parse index.json: %w", err)
return "", "", fmt.Errorf("failed to parse index.json: %w", err)
}
if len(index.Manifests) == 0 {
return "", fmt.Errorf("no manifests found in index.json")
return "", "", fmt.Errorf("no manifests found in index.json")
}
// Get the first manifest digest
@@ -56,7 +56,7 @@ func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath
// Read manifest to find layers
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
return "", fmt.Errorf("failed to read manifest: %w", err)
return "", "", fmt.Errorf("failed to read manifest: %w", err)
}
var manifest struct {
@@ -65,50 +65,64 @@ func ExtractAlgorithm(ctx context.Context, logger *slog.Logger, ociDir, destPath
} `json:"layers"`
}
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return "", fmt.Errorf("failed to parse manifest: %w", err)
return "", "", fmt.Errorf("failed to parse manifest: %w", err)
}
// Extract layers to find algorithm files
logger.Debug("found layers in manifest", "count", len(manifest.Layers))
var algorithmPath string
var requirementsPath string
var allSeenFiles []string
// Iterate layers in reverse order to find user code first (usually in top layers)
// Process layers in reverse order (top layers first)
for i := len(manifest.Layers) - 1; i >= 0; i-- {
layer := manifest.Layers[i]
layerPath := filepath.Join(ociDir, "blobs", strings.Replace(layer.Digest, ":", "/", 1))
// Try to extract and find algorithm file
algoPath, seenFiles, err := extractLayerAndFindAlgorithm(logger, layerPath, destPath)
algoP, reqP, seenFiles, err := extractLayerAndFindAlgorithm(logger, layerPath, destPath, algoType)
if len(seenFiles) > 0 {
allSeenFiles = append(allSeenFiles, seenFiles...)
}
if err != nil {
logger.Warn(fmt.Sprintf("error extracting layer %s: %v", layer.Digest, err))
logger.Warn("failed to extract layer", "digest", layer.Digest, "error", err)
continue
}
if algoPath != "" {
return algoPath, nil
if algoP != "" && algorithmPath == "" {
algorithmPath = algoP
}
if reqP != "" && requirementsPath == "" {
requirementsPath = reqP
}
// If we found both, we can stop
if algorithmPath != "" && (algoType != "python" || requirementsPath != "") {
break
}
}
return "", fmt.Errorf("no algorithm file found in OCI image layers (seen: %v)", allSeenFiles)
if algorithmPath == "" {
return "", "", fmt.Errorf("no algorithm file found. Seen files: %v", allSeenFiles)
}
return algorithmPath, requirementsPath, nil
}
// extractLayerAndFindAlgorithm extracts a layer and searches for algorithm files.
func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath string) (string, []string, error) {
func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath, algoType string) (string, string, []string, error) {
// Open layer file
layerFile, err := os.Open(layerPath)
if err != nil {
return "", nil, fmt.Errorf("failed to open layer: %w", err)
return "", "", nil, fmt.Errorf("failed to open layer: %w", err)
}
defer layerFile.Close()
// Decompress gzip
gzReader, err := gzip.NewReader(layerFile)
if err != nil {
return "", nil, fmt.Errorf("failed to create gzip reader: %w", err)
return "", "", nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzReader.Close()
@@ -116,7 +130,8 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
tarReader := tar.NewReader(gzReader)
var algorithmPath string
var seenFiles []string
var requirementsPath string
seenFiles := []string{}
for {
header, err := tarReader.Next()
@@ -124,7 +139,7 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
break
}
if err != nil {
return "", seenFiles, fmt.Errorf("failed to read tar header: %w", err)
return "", "", seenFiles, fmt.Errorf("failed to read tar header: %w", err)
}
logger.Debug("inspecting file in layer", "name", header.Name, "type", header.Typeflag)
@@ -137,7 +152,7 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
seenFiles = append(seenFiles, header.Name)
// Check if this is an algorithm file or requirements.txt
isAlgo := isAlgorithmFile(header.Name)
isAlgo := isAlgorithmFile(header.Name, header.Mode, algoType)
isReq := filepath.Base(header.Name) == "requirements.txt"
if isAlgo || isReq {
@@ -151,56 +166,69 @@ func extractLayerAndFindAlgorithm(logger *slog.Logger, layerPath, destPath strin
targetPath := filepath.Join(destPath, cleanName)
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return "", seenFiles, fmt.Errorf("failed to create dir: %w", err)
return "", "", seenFiles, fmt.Errorf("failed to create dir: %w", err)
}
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
if err != nil {
return "", seenFiles, fmt.Errorf("failed to create file: %w", err)
return "", "", seenFiles, fmt.Errorf("failed to create file: %w", err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return "", seenFiles, fmt.Errorf("failed to write file: %w", err)
return "", "", seenFiles, fmt.Errorf("failed to write file: %w", err)
}
outFile.Close()
if isAlgo {
if isAlgo && algorithmPath == "" {
algorithmPath = targetPath
}
// Continue scanning to extract other files (like requirements.txt)
if isReq && requirementsPath == "" {
requirementsPath = targetPath
}
}
}
return algorithmPath, seenFiles, nil
return algorithmPath, requirementsPath, seenFiles, nil
}
// isAlgorithmFile checks if a file is likely an algorithm file.
func isAlgorithmFile(filename string) bool {
// Common algorithm file extensions
algorithmExts := []string{".py", ".wasm", ".wat", ".js", ".sh"}
// isAlgorithmFile checks if a file is likely an algorithm file based on its name, mode and expected algorithm type.
func isAlgorithmFile(filename string, mode int64, algoType string) bool {
base := filepath.Base(filename)
baseLower := strings.ToLower(base)
// Common algorithm file names
algorithmNames := []string{"algorithm", "main", "run", "execute"}
base := filepath.Base(filename)
baseLower := strings.ToLower(base)
// Check extensions
for _, ext := range algorithmExts {
if strings.HasSuffix(baseLower, ext) {
return true
switch algoType {
case "python":
return strings.HasSuffix(baseLower, ".py")
case "wasm":
return strings.HasSuffix(baseLower, ".wasm") || strings.HasSuffix(baseLower, ".wat")
case "bin":
// Ensure it doesn't have a known non-binary extension
nonBinExts := []string{".py", ".wasm", ".wat", ".js", ".sh", ".csv", ".json", ".txt", ".md"}
for _, ext := range nonBinExts {
if strings.HasSuffix(baseLower, ext) {
return false
}
}
}
// Check common names
for _, name := range algorithmNames {
if strings.Contains(baseLower, name) {
return true
// Check for common names
for _, name := range algorithmNames {
if strings.Contains(baseLower, name) {
return true
}
}
// Check if it's executable (at least one 'x' bit set)
return mode&0o111 != 0
case "docker":
// Docker algorithms are the whole image, this function shouldn't be used for them
return false
default:
// Unknown or empty algoType - no generic fallback to ensure explicit type usage
return false
}
return false
}
// ExtractDataset extracts dataset files from an OCI image directory.
@@ -328,7 +356,7 @@ func extractLayerDataFiles(layerPath, destPath string) ([]string, error) {
// isDataFile checks if a file is likely a dataset file.
func isDataFile(filename string) bool {
dataExts := []string{".csv", ".json", ".txt", ".parquet", ".arrow", ".dat"}
dataExts := []string{".csv", ".json", ".txt", ".parquet", ".arrow", ".dat", ".zip", ".tar", ".gz", ".tgz", ".tar.gz"}
baseLower := strings.ToLower(filepath.Base(filename))
+246 -29
View File
@@ -24,28 +24,28 @@ func TestIsAlgorithmFile(t *testing.T) {
tests := []struct {
name string
filename string
mode int64
algoType string
want bool
}{
{"Python file", "algorithm.py", true},
{"WASM file", "module.wasm", true},
{"WAT file", "module.wat", true},
{"JavaScript file", "script.js", true},
{"Shell script", "run.sh", true},
{"Main python file", "main.py", true},
{"Execute file", "execute.py", true},
{"Algorithm name in path", "src/algorithm_v2.py", true},
{"Random python file", "helper.py", true},
{"CSV data file", "data.csv", false},
{"JSON config file", "config.json", false},
{"Text file", "readme.txt", false},
{"Binary file", "data.bin", false},
{"Uppercase extension", "MAIN.PY", true},
{"Mixed case", "Algorithm.Py", true},
{"Python file", "algorithm.py", 0o644, "python", true},
{"WASM file", "module.wasm", 0o644, "wasm", true},
{"WAT file", "module.wat", 0o644, "wasm", true},
{"Python file as bin", "algorithm.py", 0o755, "bin", false},
{"Main python file", "main.py", 0o644, "python", true},
{"Binary file with common name", "algorithm", 0o644, "bin", true},
{"Binary file with common name run", "run", 0o644, "bin", true},
{"Executable binary", "my-app", 0o755, "bin", true},
{"CSV data file", "data.csv", 0o755, "python", false},
{"JSON config file", "config.json", 0o755, "wasm", false},
{"Text file", "readme.txt", 0o755, "bin", false},
{"Uppercase extension", "MAIN.PY", 0o644, "python", true},
{"Mixed case", "Algorithm.Py", 0o644, "python", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isAlgorithmFile(tt.filename)
got := isAlgorithmFile(tt.filename, tt.mode, tt.algoType)
assert.Equal(t, tt.want, got)
})
}
@@ -83,7 +83,7 @@ func TestExtractAlgorithm(t *testing.T) {
t.Run("missing index.json", func(t *testing.T) {
tempDir := t.TempDir()
_, err := ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
_, _, err := ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read index.json")
})
@@ -93,7 +93,7 @@ func TestExtractAlgorithm(t *testing.T) {
err := os.WriteFile(filepath.Join(tempDir, "index.json"), []byte("not json"), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
_, _, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse index.json")
})
@@ -105,14 +105,14 @@ func TestExtractAlgorithm(t *testing.T) {
err := os.WriteFile(filepath.Join(tempDir, "index.json"), data, 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir())
_, _, err = ExtractAlgorithm(context.Background(), logger, tempDir, t.TempDir(), "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no manifests found")
})
t.Run("successful extraction", func(t *testing.T) {
ociDir, destDir := setupTestOCIImage(t, "algorithm.py", testPythonScript)
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
require.NoError(t, err)
assert.NotEmpty(t, algoPath)
assert.Contains(t, algoPath, "algorithm.py")
@@ -471,7 +471,7 @@ func TestExtractAlgorithmWithRequirements(t *testing.T) {
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
require.NoError(t, err)
assert.Contains(t, algoPath, "main.py")
@@ -537,7 +537,7 @@ func TestExtractAlgorithmNoAlgoFile(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
@@ -600,6 +600,41 @@ func TestExtractDatasetNoDataFiles(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "no dataset files found")
})
t.Run("corrupt layer file", func(t *testing.T) {
ociDir := t.TempDir()
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "layer123"), []byte("not a gzip"), 0o644))
manifest := struct {
Layers []struct {
Digest string `json:"digest"`
} `json:"layers"`
}{
Layers: []struct {
Digest string `json:"digest"`
}{{Digest: "sha256:layer123"}},
}
manifestData, _ := json.Marshal(manifest)
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
index := OCIIndex{
SchemaVersion: 2,
Manifests: []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int `json:"size"`
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
}
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
// ExtractDataset logs a warning and continues if a layer fails, but if ALL fail it errors
_, err := ExtractDataset(ociDir, t.TempDir())
assert.Error(t, err)
})
}
func TestExtractAlgorithmInvalidManifest(t *testing.T) {
@@ -626,7 +661,7 @@ func TestExtractAlgorithmInvalidManifest(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse manifest")
})
@@ -654,7 +689,7 @@ func TestExtractAlgorithmMissingManifest(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read manifest")
})
@@ -723,7 +758,7 @@ func TestExtractAlgorithmWithDirectory(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
require.NoError(t, err)
assert.Contains(t, algoPath, "main.py")
})
@@ -795,7 +830,7 @@ func TestExtractAlgorithmPathTraversal(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
algoPath, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
algoPath, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
require.NoError(t, err)
assert.Contains(t, algoPath, "algorithm.py")
@@ -815,7 +850,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
err := os.WriteFile(layerPath, []byte("not gzip"), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
@@ -833,7 +868,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
err = os.WriteFile(layerPath, buf.Bytes(), 0o644)
require.NoError(t, err)
_, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
@@ -867,7 +902,7 @@ func TestExtractAlgorithmErrorPathsAdditional(t *testing.T) {
indexData, _ := json.Marshal(index)
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
_, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir)
_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no algorithm file found")
})
@@ -918,3 +953,185 @@ func TestExtractDatasetErrorPathsAdditional(t *testing.T) {
assert.Contains(t, err.Error(), "no dataset files found")
})
}
// TestExtractAlgorithmAdditionalTypes checks that isAlgorithmFile rejects files
// for the "docker" type and for an unrecognised type.
func TestExtractAlgorithmAdditionalTypes(t *testing.T) {
	t.Run("isAlgorithmFile additional types", func(t *testing.T) {
		for _, algoType := range []string{"docker", "unknown"} {
			assert.False(t, isAlgorithmFile("any", 0o644, algoType))
		}
	})
}
// TestExtractAlgorithmErrorPathsInternal drives ExtractAlgorithm into its
// filesystem failure branches by pre-occupying the paths it needs to create.
func TestExtractAlgorithmErrorPathsInternal(t *testing.T) {
	logger := slog.Default()

	t.Run("failed to create directory", func(t *testing.T) {
		ociDir, destDir := setupTestOCIImage(t, "algorithm.py", "print('hello')")

		// Create a file where a directory should be.
		blockedDir := filepath.Join(destDir, "blocked")
		require.NoError(t, os.WriteFile(blockedDir, []byte("data"), 0o644))

		// Write a layer whose only entry lives under "blocked/", so extraction
		// has to create a directory at a path that is already a regular file.
		// NOTE(review): presumably this replaces the layer blob produced by
		// setupTestOCIImage — confirm against that helper.
		layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
		content := []byte("print(1)")

		var buf bytes.Buffer
		gw := gzip.NewWriter(&buf)
		tw := tar.NewWriter(gw)
		hdr := &tar.Header{
			Name: "blocked/main.py",
			Mode: 0o644,
			Size: int64(len(content)),
		}
		require.NoError(t, tw.WriteHeader(hdr))
		// Fixture construction errors are asserted so that a broken archive
		// cannot make the final assert.Error pass for the wrong reason.
		_, err := tw.Write(content)
		require.NoError(t, err)
		require.NoError(t, tw.Close())
		require.NoError(t, gw.Close())
		require.NoError(t, os.WriteFile(layerPath, buf.Bytes(), 0o644))

		_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
		assert.Error(t, err)
	})

	t.Run("failed to create file", func(t *testing.T) {
		ociDir, destDir := setupTestOCIImage(t, "algorithm.py", "print('hello')")

		// Create a directory where a file should be.
		blockedFile := filepath.Join(destDir, "algorithm.py")
		require.NoError(t, os.MkdirAll(blockedFile, 0o755))

		_, _, err := ExtractAlgorithm(context.Background(), logger, ociDir, destDir, "python")
		assert.Error(t, err)
	})
}
// TestExtractDatasetErrorPathsInternal drives ExtractDataset into a filesystem
// failure branch by pre-occupying a path it needs to create.
func TestExtractDatasetErrorPathsInternal(t *testing.T) {
	t.Run("failed to create directory for dataset", func(t *testing.T) {
		ociDir, destDir := setupTestOCIImage(t, "data.csv", "a,b,c")

		// A regular file named "blocked" prevents creating a directory there.
		blockedDir := filepath.Join(destDir, "blocked")
		require.NoError(t, os.WriteFile(blockedDir, []byte("data"), 0o644))

		// Write a layer whose only entry lives under "blocked/".
		// NOTE(review): presumably this replaces the layer blob produced by
		// setupTestOCIImage — confirm against that helper.
		layerPath := filepath.Join(ociDir, "blobs", "sha256", "layer123")
		content := []byte("a,b")

		var buf bytes.Buffer
		gw := gzip.NewWriter(&buf)
		tw := tar.NewWriter(gw)
		hdr := &tar.Header{
			Name: "blocked/data.csv",
			Mode: 0o644,
			Size: int64(len(content)),
		}
		require.NoError(t, tw.WriteHeader(hdr))
		// Fixture construction errors are asserted so that a broken archive
		// cannot make the final assert.Error pass for the wrong reason.
		_, err := tw.Write(content)
		require.NoError(t, err)
		require.NoError(t, tw.Close())
		require.NoError(t, gw.Close())
		require.NoError(t, os.WriteFile(layerPath, buf.Bytes(), 0o644))

		_, err = ExtractDataset(ociDir, destDir)
		assert.Error(t, err)
	})
}
// TestExtractAlgorithm_PythonNoRequirements extracts a python image built from a
// single main.py and expects an algorithm path with an empty requirements path.
func TestExtractAlgorithm_PythonNoRequirements(t *testing.T) {
	ociDir, destDir := setupTestOCIImage(t, "main.py", testPythonScript)

	algo, reqs, err := ExtractAlgorithm(context.Background(), slog.Default(), ociDir, destDir, "python")

	require.NoError(t, err)
	assert.NotEmpty(t, algo)
	assert.Empty(t, reqs)
}
// TestExtractDataset_MultipleLayers builds an OCI layout with two layers, each
// holding one CSV file, and expects ExtractDataset to return two files.
func TestExtractDataset_MultipleLayers(t *testing.T) {
	ociDir, destDir := t.TempDir(), t.TempDir()

	blobsDir := filepath.Join(ociDir, "blobs", "sha256")
	require.NoError(t, os.MkdirAll(blobsDir, 0o755))

	// writeLayer stores a gzip-compressed tar containing one file at
	// blobsDir/<blob> and returns the "sha256:<blob>" digest string for it.
	writeLayer := func(blob, entryName, body string) string {
		out, err := os.Create(filepath.Join(blobsDir, blob))
		require.NoError(t, err)

		gz := gzip.NewWriter(out)
		tarW := tar.NewWriter(gz)

		require.NoError(t, tarW.WriteHeader(&tar.Header{Name: entryName, Mode: 0o644, Size: int64(len(body))}))
		_, err = tarW.Write([]byte(body))
		require.NoError(t, err)

		// Close innermost first so every writer flushes into the one below it.
		require.NoError(t, tarW.Close())
		require.NoError(t, gz.Close())
		require.NoError(t, out.Close())

		return "sha256:" + blob
	}

	firstDigest := writeLayer("l1", "data1.csv", "1,2")
	secondDigest := writeLayer("l2", "data2.csv", "3,4")

	// Manifest listing both layers, stored as blob "m1".
	type layerRef = struct {
		Digest string `json:"digest"`
	}
	manifest := struct {
		Layers []layerRef `json:"layers"`
	}{
		Layers: []layerRef{{Digest: firstDigest}, {Digest: secondDigest}},
	}
	manifestData, err := json.Marshal(manifest)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "m1"), manifestData, 0o644))

	// index.json pointing at the manifest blob.
	index := OCIIndex{
		SchemaVersion: 2,
		Manifests: []struct {
			MediaType string `json:"mediaType"`
			Digest    string `json:"digest"`
			Size      int    `json:"size"`
		}{{Digest: "sha256:m1", Size: len(manifestData)}},
	}
	indexData, err := json.Marshal(index)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))

	extracted, err := ExtractDataset(ociDir, destDir)
	require.NoError(t, err)
	assert.Len(t, extracted, 2)
}
// TestExtractAlgorithm_ErrorPaths covers ExtractAlgorithm failures caused by
// unusable layer blobs.
func TestExtractAlgorithm_ErrorPaths(t *testing.T) {
	logger := slog.Default()

	t.Run("invalid layer gzip", func(t *testing.T) {
		ociDir := t.TempDir()
		blobsDir := filepath.Join(ociDir, "blobs", "sha256")
		require.NoError(t, os.MkdirAll(blobsDir, 0o755))

		// The only layer blob is plain text rather than a gzip stream.
		require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "l1"), []byte("not gzip"), 0o644))

		manifest := struct {
			Layers []struct {
				Digest string `json:"digest"`
			} `json:"layers"`
		}{
			Layers: []struct {
				Digest string `json:"digest"`
			}{{Digest: "sha256:l1"}},
		}
		// Marshal errors are asserted instead of discarded so a broken fixture
		// fails here rather than surfacing as an unrelated extraction error.
		manifestData, err := json.Marshal(manifest)
		require.NoError(t, err)
		require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "m1"), manifestData, 0o644))

		index := OCIIndex{
			SchemaVersion: 2,
			Manifests: []struct {
				MediaType string `json:"mediaType"`
				Digest    string `json:"digest"`
				Size      int    `json:"size"`
			}{{Digest: "sha256:m1", Size: len(manifestData)}},
		}
		indexData, err := json.Marshal(index)
		require.NoError(t, err)
		require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))

		_, _, err = ExtractAlgorithm(context.Background(), logger, ociDir, t.TempDir(), "bin")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "no algorithm file found")
	})
}
+17
View File
@@ -109,6 +109,23 @@ func (s *SkopeoClient) Inspect(ctx context.Context, imageRef string) (*ImageMani
}, nil
}
// ToDockerArchive converts an OCI directory to a Docker archive tarball.
//
// It runs "skopeo copy" from oci:<ociDir> to docker-archive:<destFile> with the
// client's work directory as the working directory and the OCICrypt keyprovider
// configuration added to the environment. On failure the returned error wraps
// the command error and includes the combined stdout/stderr output.
func (s *SkopeoClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
	src := "oci:" + ociDir
	dst := "docker-archive:" + destFile

	cmd := exec.CommandContext(ctx, s.skopeoPath,
		"copy",
		"--insecure-policy",
		"--src-tls-verify=false",
		"--dest-tls-verify=false",
		src,
		dst,
	)
	cmd.Dir = s.workDir
	cmd.Env = append(os.Environ(), OCICryptKeyproviderConfig+"="+DefaultOCICryptConfig)

	out, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}

	return fmt.Errorf("skopeo copy to docker-archive failed: %w\nOutput: %s", err, string(out))
}
// GetLocalImagePath returns the path to a local OCI image directory.
func (s *SkopeoClient) GetLocalImagePath(name string) string {
return filepath.Join(s.workDir, name)
+16
View File
@@ -183,3 +183,19 @@ func TestSkopeoClientPullAndDecryptEncrypted(t *testing.T) {
assert.Contains(t, err.Error(), "skopeo copy failed")
})
}
// TestSkopeoClient_ToDockerArchive checks the error reported by ToDockerArchive
// for a source OCI directory that does not exist. The test is skipped when a
// SkopeoClient cannot be constructed.
func TestSkopeoClient_ToDockerArchive(t *testing.T) {
	client, err := NewSkopeoClient(t.TempDir())
	if err != nil {
		t.Skip("skopeo not installed, skipping test")
	}

	t.Run("invalid oci directory", func(t *testing.T) {
		archivePath := filepath.Join(t.TempDir(), "archive.tar")

		convErr := client.ToDockerArchive(context.Background(), "/non/existent/oci/dir", archivePath)

		assert.Error(t, convErr)
		assert.Contains(t, convErr.Error(), "skopeo copy to docker-archive failed")
	})
}
BIN
View File
Binary file not shown.
+11 -5
View File
@@ -40,6 +40,11 @@ The service is configured using environment variables from the following table.
| `-algo-kbs-path` | Algorithm KBS resource path (e.g., 'default/key/algo-key') |
| `-dataset-source-urls` | Comma-separated dataset source URLs |
| `-dataset-kbs-paths` | Comma-separated dataset KBS resource paths |
| `-algo-type` | Algorithm execution type (`bin`, `python`, `wasm`, `docker`) |
| `-algo-args` | Comma-separated algorithm arguments |
| `-algo-hash` | Expected SHA3-256 hash of decrypted algorithm (hex) |
| `-dataset-hash` | Expected SHA3-256 hash of decrypted dataset (hex) |
| `-dataset-decompress` | Whether to decompress datasets (`true` or `false`) |
### Optional Flags
@@ -114,11 +119,12 @@ go run ./test/cvms/main.go \
## Notes
- **Either** `-algo-path` **OR** (`-algo-source-url` AND `-algo-kbs-path`) must be provided
- When using remote datasets, `-dataset-source-urls` and `-dataset-kbs-paths` must have the same number of comma-separated values
- The `-kbs-url` flag should be provided when using any remote resources
- For remote resources, the hash values in the manifest are currently placeholders (all zeros). In production, these should be the actual hashes of the **decrypted** data
- See [TESTING_REMOTE_RESOURCES.md](../TESTING_REMOTE_RESOURCES.md) for a complete guide on testing remote resource downloads with KBS attestation
- **Either** `-algo-path` **OR** (`-algo-source-url` AND `-algo-kbs-path`) must be provided.
- When using remote datasets, `-dataset-source-urls` and `-dataset-kbs-paths` must have the same number of comma-separated values.
- The `-kbs-url` flag should be provided when using any remote resources.
- **Checksum Verification**: For remote resources, you must provide the actual SHA3-256 hash of the **decrypted plaintext** content via `-algo-hash` and `-dataset-hash`. The Agent will verify this hash after downloading and decrypting the resource.
- **Calculating Hashes**: Use `cocos-cli checksum <path>` on your local source files (or directories) to generate the correct hash for the manifest.
- See [TESTING_REMOTE_RESOURCES.md](../../agent/TESTING_REMOTE_RESOURCES.md) for a complete guide on testing remote resource downloads with KBS attestation.
## Architecture
+25 -9
View File
@@ -140,13 +140,15 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
return
}
dataHash, err := internal.Checksum(dataPath)
dataHash, err := internal.ChecksumHex(dataPath)
if err != nil {
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
return
}
s.logger.Info("local dataset checksum", "path", dataPath, "hash", dataHash)
datasets = append(datasets, &cvms.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
hashBytes, _ := hex.DecodeString(dataHash)
datasets = append(datasets, &cvms.Dataset{Hash: hashBytes, UserKey: pubPem.Bytes})
}
}
@@ -170,11 +172,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
algoHashBytes = make([]byte, 32)
}
var algoArgs []string
if algoArgsString != "" {
algoArgs = strings.Split(algoArgsString, ",")
}
algorithm = &cvms.Algorithm{
Hash: algoHashBytes,
UserKey: pubPem.Bytes,
AlgoType: algoType,
AlgoArgs: strings.Split(algoArgsString, ","),
AlgoArgs: algoArgs,
Source: &cvms.Source{
Type: "oci-image",
Url: algoSourceURL,
@@ -184,16 +191,25 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
}
} else {
// Direct upload mode - use local file
if algoPath == "" {
s.logger.Error("algorithm path is required when not using remote source")
return
}
algoHash, err := internal.Checksum(algoPath)
fileHash, err := internal.ChecksumHex(algoPath)
if err != nil {
s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err))
return
}
algorithm = &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes}
s.logger.Info("local algorithm checksum", "path", algoPath, "hash", fileHash)
var algoArgs []string
if algoArgsString != "" {
algoArgs = strings.Split(algoArgsString, ",")
}
hashBytes, _ := hex.DecodeString(fileHash)
algorithm = &cvms.Algorithm{
Hash: hashBytes,
UserKey: pubPem.Bytes,
AlgoType: algoType,
AlgoArgs: algoArgs,
}
}
// Build KBS config