mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
Compare commits
120 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6169766666 | |||
| 5f339d2fab | |||
| 7e8eab77e7 | |||
| 9f31e2472b | |||
| e8e616ff62 | |||
| 0dce9d3083 | |||
| a37121dc7b | |||
| 1f0eccfae7 | |||
| 02aa7d7d85 | |||
| 27db9b29eb | |||
| 81fe0b11b5 | |||
| d5badba547 | |||
| c59a413765 | |||
| 3b9841a973 | |||
| b44780df95 | |||
| 80bf813c48 | |||
| 42b05524c8 | |||
| c1cbcec851 | |||
| da31d76c94 | |||
| f77ec5644a | |||
| 207bfd99af | |||
| de50b6d2d4 | |||
| a3265bc346 | |||
| ee52551ca4 | |||
| 5ae4f0f401 | |||
| 0a850b6bab | |||
| a69dbda46b | |||
| dde4249abc | |||
| 97ee07979e | |||
| 48310fb9e6 | |||
| a128895ede | |||
| 9d900d40f6 | |||
| 5a4ac9d720 | |||
| fdcde2b9aa | |||
| 3498db14fb | |||
| c422afe0a6 | |||
| 3f06971976 | |||
| 9d8bb90476 | |||
| e634b67bc5 | |||
| 291755ec87 | |||
| de8e198b71 | |||
| 3b1605da77 | |||
| 77a11c6535 | |||
| 364724ff1b | |||
| e382664a6a | |||
| fd84a37eca | |||
| cf32a252de | |||
| 2b38f4595c | |||
| 04b0cdfd5d | |||
| 6b26f40a72 | |||
| 439b041086 | |||
| 1143d4cc19 | |||
| bd92b96b63 | |||
| 93ac30d1a9 | |||
| 817ac6c35c | |||
| 6811a2481b | |||
| 0ffc2d17cf | |||
| 0be724386b | |||
| 7e59ca09fc | |||
| 3aed6df66e | |||
| fc5eff9ff0 | |||
| 622f499a76 | |||
| 5783055e67 | |||
| c758b3b216 | |||
| 906d7877b2 | |||
| 5377dd4d7f | |||
| 1e2e635e69 | |||
| 541368844d | |||
| 09832e48c9 | |||
| b5daee9e74 | |||
| e42d24b536 | |||
| 24998341d9 | |||
| c0efb49ac3 | |||
| a9074e535f | |||
| 25d6b088e7 | |||
| a6cd29d2c8 | |||
| 4b27b98edb | |||
| 654e22bba5 | |||
| 3cec8e2076 | |||
| 3e02cde7a2 | |||
| ccab296b62 | |||
| be423e0231 | |||
| 92ba15d2de | |||
| 49a66d6f35 | |||
| 8eb1fac9ad | |||
| 4b657e5313 | |||
| 38c2abb294 | |||
| 4e8057f481 | |||
| 85a2b7a6c8 | |||
| 45187d7f41 | |||
| f543cb4363 | |||
| cef47baed7 | |||
| 698bd948ed | |||
| 9c8ddfd2b1 | |||
| 79c66a89c3 | |||
| f52702b631 | |||
| 31c7833c3d | |||
| 64bf7a56ac | |||
| bd59a4a617 | |||
| c9af8a166b | |||
| 17c6accbff | |||
| 2d6d276061 | |||
| e8c2ccc071 | |||
| f1af397aa0 | |||
| 77325753f8 | |||
| 3e474338c5 | |||
| 5960b06126 | |||
| 636d3dcaa0 | |||
| 92f4f0535a | |||
| bf84f45306 | |||
| f5b67ca35b | |||
| 3bb0b2a315 | |||
| 434d58f890 | |||
| 32e2bfb881 | |||
| 26bf5dc643 | |||
| bda3968fdf | |||
| 90807d9576 | |||
| 94c169febb | |||
| 3102114ff3 | |||
| 5c60bc2a48 |
@@ -1,15 +1,5 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo"
|
||||
directory: "/scripts/attestation_policy"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
groups:
|
||||
rs-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
|
||||
@@ -9,6 +9,7 @@ on:
|
||||
- "pkg/manager/*.pb.go"
|
||||
- "agent/agent.proto"
|
||||
- "agent/*.pb.go"
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
@@ -29,13 +30,13 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.23.x
|
||||
go-version: 1.26.x
|
||||
|
||||
- name: Set up protoc
|
||||
run: |
|
||||
PROTOC_VERSION=29.0
|
||||
PROTOC_GEN_VERSION=v1.36.5
|
||||
PROTOC_GRPC_VERSION=v1.5.1
|
||||
PROTOC_VERSION=33.1
|
||||
PROTOC_GEN_VERSION=v1.36.11
|
||||
PROTOC_GRPC_VERSION=v1.6.0
|
||||
|
||||
# Download and install protoc
|
||||
PROTOC_ZIP=protoc-$PROTOC_VERSION-linux-x86_64.zip
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.23.x
|
||||
go-version: 1.26.x
|
||||
cache-dependency-path: "go.sum"
|
||||
|
||||
- name: Checkout cocos
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
with:
|
||||
repository: "buildroot/buildroot"
|
||||
path: buildroot
|
||||
ref: 2024.11-rc2
|
||||
ref: 2025.08-rc3
|
||||
|
||||
- name: Build hal
|
||||
run: |
|
||||
|
||||
+50
-24
@@ -9,9 +9,8 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -19,39 +18,66 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.23.x
|
||||
go-version: 1.26.x
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: v2.0.2
|
||||
version: v2.11.1
|
||||
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
make
|
||||
run: make
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: lint
|
||||
strategy:
|
||||
matrix:
|
||||
module: [agent, cli, cmd, internal, pkg, manager]
|
||||
include:
|
||||
- module: manager
|
||||
sudo: true
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.26.x
|
||||
|
||||
- name: Create coverage directory
|
||||
run: mkdir -p coverage
|
||||
|
||||
- name: Run tests for ${{ matrix.module }}
|
||||
run: |
|
||||
mkdir coverage
|
||||
if [[ "${{ matrix.module }}" == "manager" ]]; then
|
||||
sudo GOTOOLCHAIN=go1.26.0+auto go test -v --race -covermode=atomic -coverprofile coverage/${{ matrix.module }}.out ./${{ matrix.module }}/...
|
||||
else
|
||||
GOTOOLCHAIN=go1.26.0+auto go test -v --race -covermode=atomic -coverprofile coverage/${{ matrix.module }}.out ./${{ matrix.module }}/...
|
||||
fi
|
||||
|
||||
- name: Run Agent tests
|
||||
run: go test --tags embed -v --race -covermode=atomic -coverprofile coverage/agent.out ./agent/...
|
||||
- name: Upload coverage artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-${{ matrix.module }}
|
||||
path: coverage/${{ matrix.module }}.out
|
||||
retention-days: 1
|
||||
|
||||
- name: Run cli tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/cli.out ./cli/...
|
||||
upload-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run cmd tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/cmd.out ./cmd/...
|
||||
|
||||
- name: Run internal tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/internal.out ./internal/...
|
||||
|
||||
- name: Run pkg tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/pkg.out ./pkg/...
|
||||
|
||||
- name: Run manager tests
|
||||
run: sudo go test -v --race -covermode=atomic -coverprofile coverage/manager.out ./manager/...
|
||||
- name: Download all coverage artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: coverage-*
|
||||
path: coverage/
|
||||
merge-multiple: true
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
name: Rust CI Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "scripts/attestation_policy/**"
|
||||
- ".github/workflows/rust.yaml"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "scripts/attestation_policy/**"
|
||||
- ".github/workflows/rust.yaml"
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
rust-check:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./scripts/attestation_policy
|
||||
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check cargo
|
||||
run: cargo check --release --all-targets
|
||||
|
||||
- name: Check formatting
|
||||
run: cargo fmt --all -- --check
|
||||
|
||||
- name: Run linter
|
||||
run: cargo clippy -- -D warnings
|
||||
|
||||
- name: Build for all features
|
||||
run: cargo build --release --all-features
|
||||
@@ -19,9 +19,15 @@ target/
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
!tools/nvidia-attestation-helper/Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
*.enc
|
||||
*.key
|
||||
*.pub
|
||||
.codex
|
||||
|
||||
+4
-6
@@ -55,12 +55,6 @@ linters:
|
||||
template: |-
|
||||
Copyright (c) Ultraviolet
|
||||
SPDX-License-Identifier: Apache-2.0
|
||||
importas:
|
||||
alias:
|
||||
- pkg: github.com/absmach/magistrala/logger
|
||||
alias: mglog
|
||||
no-unaliased: true
|
||||
no-extra-aliases: false
|
||||
staticcheck:
|
||||
checks:
|
||||
- -ST1000
|
||||
@@ -76,10 +70,14 @@ linters:
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- linters:
|
||||
- errcheck
|
||||
path: build/
|
||||
- linters:
|
||||
- makezero
|
||||
text: with non-zero initialized length
|
||||
paths:
|
||||
- build
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
+162
@@ -0,0 +1,162 @@
|
||||
pkgname: mocks
|
||||
template: testify
|
||||
template-data:
|
||||
boilerplate-file: ./boilerplate.txt
|
||||
unroll-variadic: true
|
||||
packages:
|
||||
github.com/ultravioletrs/cocos/agent:
|
||||
interfaces:
|
||||
AgentService_AlgoClient:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
AgentService_DataClient:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
AgentService_IMAMeasurementsClient:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
Service:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/algorithm:
|
||||
interfaces:
|
||||
Algorithm:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/auth:
|
||||
interfaces:
|
||||
Authenticator:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage:
|
||||
interfaces:
|
||||
Storage:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/cvms/server:
|
||||
interfaces:
|
||||
AgentServer:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/events:
|
||||
interfaces:
|
||||
Service:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/agent/statemachine:
|
||||
interfaces:
|
||||
StateMachine:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/server:
|
||||
interfaces:
|
||||
Server:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/manager:
|
||||
interfaces:
|
||||
ManagerServiceClient:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
Service:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/manager/qemu:
|
||||
interfaces:
|
||||
Persistence:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/manager/vm:
|
||||
interfaces:
|
||||
StateMachine:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
VM:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/attestation:
|
||||
interfaces:
|
||||
Provider:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
Verifier:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig:
|
||||
interfaces:
|
||||
MeasurementProvider:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/clients/grpc:
|
||||
interfaces:
|
||||
Client:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/sdk:
|
||||
interfaces:
|
||||
SDK:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/atls:
|
||||
interfaces:
|
||||
CertificateProvider:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/clients/grpc/runner:
|
||||
interfaces:
|
||||
Client:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/internal/proto/attestation-agent:
|
||||
interfaces:
|
||||
AttestationAgentServiceClient:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
@@ -1,12 +1,24 @@
|
||||
BUILD_DIR = build
|
||||
SERVICES = manager agent cli
|
||||
ATTESTATION_POLICY = attestation_policy
|
||||
CGO_ENABLED ?= 1
|
||||
SERVICES = manager agent cli attestation-service log-forwarder computation-runner egress-proxy ingress-proxy
|
||||
NVIDIA_ATTESTATION_HELPER = nvidia-attestation-helper
|
||||
NVIDIA_ATTESTATION_HELPER_DIR = tools/$(NVIDIA_ATTESTATION_HELPER)
|
||||
NVIDIA_ATTESTATION_HELPER_MANIFEST = $(NVIDIA_ATTESTATION_HELPER_DIR)/Cargo.toml
|
||||
NVIDIA_ATTESTATION_HELPER_BINARY = $(BUILD_DIR)/$(NVIDIA_ATTESTATION_HELPER)
|
||||
NVIDIA_ATTESTATION_HELPER_LIB_DIR = $(BUILD_DIR)/lib
|
||||
NVAT_SDK_CPP_DIR ?= $(firstword $(wildcard $(HOME)/.cargo/git/checkouts/attestation-sdk-*/*/nv-attestation-sdk-cpp))
|
||||
NVAT_SDK_CPP_BUILD_DIR ?= $(NVAT_SDK_CPP_DIR)/build
|
||||
NVAT_SDK_HEADER ?= $(NVAT_SDK_CPP_BUILD_DIR)/include/nvat.h
|
||||
NVAT_SDK_SHARED_LIB ?= $(NVAT_SDK_CPP_BUILD_DIR)/libnvat.so.1
|
||||
NVAT_SYSTEM_HEADER ?= /usr/include/nvat.h
|
||||
CARGO ?= cargo
|
||||
CMAKE ?= cmake
|
||||
CGO_ENABLED ?= 0
|
||||
GOARCH ?= amd64
|
||||
VERSION ?= $(shell git describe --abbrev=0 --tags --always)
|
||||
COMMIT ?= $(shell git rev-parse HEAD)
|
||||
TIME ?= $(shell date +%F_%T)
|
||||
EMBED_ENABLED ?= 0
|
||||
NVAT_USE_SYSTEM_LIB ?=
|
||||
INSTALL_DIR ?= /usr/local/bin
|
||||
CONFIG_DIR ?= /etc/cocos
|
||||
SERVICE_NAME ?= cocos-manager
|
||||
@@ -21,28 +33,61 @@ define compile_service
|
||||
-X 'github.com/absmach/magistrala.Version=$(VERSION)' \
|
||||
-X 'github.com/absmach/magistrala.Commit=$(COMMIT)'" \
|
||||
$(if $(filter 1,$(EMBED_ENABLED)),-tags "embed",) \
|
||||
-o ${BUILD_DIR}/cocos-$(1) cmd/$(1)/main.go
|
||||
-o ${BUILD_DIR}/cocos-$(1) ./cmd/$(1)
|
||||
endef
|
||||
|
||||
.PHONY: all $(SERVICES) $(ATTESTATION_POLICY) install clean
|
||||
NVIDIA_ATTESTATION_HELPER_CARGO_ENV = $(if $(filter 1,$(NVAT_USE_SYSTEM_LIB)),NVAT_USE_SYSTEM_LIB=1,)
|
||||
NVIDIA_ATTESTATION_HELPER_RUSTFLAGS = $(strip $(RUSTFLAGS) $(if $(filter 1,$(NVAT_USE_SYSTEM_LIB)),,-C link-arg=-Wl,-rpath,$$ORIGIN/lib))
|
||||
|
||||
.PHONY: all $(SERVICES) $(NVIDIA_ATTESTATION_HELPER) nvidia-attestation-helper-prereqs install clean
|
||||
|
||||
all: $(SERVICES)
|
||||
|
||||
$(SERVICES):
|
||||
$(BUILD_DIR):
|
||||
mkdir -p $(BUILD_DIR)
|
||||
|
||||
$(SERVICES): | $(BUILD_DIR)
|
||||
$(call compile_service,$@)
|
||||
@if [ "$@" = "cli" ] || [ "$@" = "manager" ]; then $(MAKE) build-igvm; fi
|
||||
|
||||
$(ATTESTATION_POLICY):
|
||||
$(MAKE) -C ./scripts/attestation_policy
|
||||
nvidia-attestation-helper-prereqs:
|
||||
ifeq ($(filter 1,$(NVAT_USE_SYSTEM_LIB)),1)
|
||||
@test -f $(NVAT_SYSTEM_HEADER) || \
|
||||
( echo "Missing $(NVAT_SYSTEM_HEADER). Install the NVAT development package or run without NVAT_USE_SYSTEM_LIB=1."; exit 1 )
|
||||
@ldconfig -p | grep -q libnvat.so.1 || \
|
||||
( echo "libnvat.so.1 not found in the dynamic linker cache. Install the NVAT runtime package or run without NVAT_USE_SYSTEM_LIB=1."; exit 1 )
|
||||
else
|
||||
@if [ -z "$(NVAT_SDK_CPP_DIR)" ]; then \
|
||||
echo "Unable to locate nv-attestation-sdk-cpp under $$HOME/.cargo/git/checkouts."; \
|
||||
echo "Run 'cargo fetch --manifest-path $(NVIDIA_ATTESTATION_HELPER_MANIFEST)' first, or install NVAT and use 'make NVAT_USE_SYSTEM_LIB=1 $(NVIDIA_ATTESTATION_HELPER)'."; \
|
||||
exit 1; \
|
||||
fi
|
||||
@if [ ! -f "$(NVAT_SDK_HEADER)" ] || [ ! -f "$(NVAT_SDK_SHARED_LIB)" ]; then \
|
||||
$(CMAKE) -S $(NVAT_SDK_CPP_DIR) -B $(NVAT_SDK_CPP_BUILD_DIR) && \
|
||||
$(CMAKE) --build $(NVAT_SDK_CPP_BUILD_DIR); \
|
||||
fi
|
||||
endif
|
||||
|
||||
$(NVIDIA_ATTESTATION_HELPER): nvidia-attestation-helper-prereqs | $(BUILD_DIR)
|
||||
RUSTFLAGS='$(NVIDIA_ATTESTATION_HELPER_RUSTFLAGS)' $(NVIDIA_ATTESTATION_HELPER_CARGO_ENV) $(CARGO) build --manifest-path $(NVIDIA_ATTESTATION_HELPER_MANIFEST) --release
|
||||
install -m 755 $(NVIDIA_ATTESTATION_HELPER_DIR)/target/release/$(NVIDIA_ATTESTATION_HELPER) $(NVIDIA_ATTESTATION_HELPER_BINARY)
|
||||
@if [ "$(filter 1,$(NVAT_USE_SYSTEM_LIB))" != "1" ]; then \
|
||||
install -d $(NVIDIA_ATTESTATION_HELPER_LIB_DIR); \
|
||||
install -m 755 $(NVAT_SDK_SHARED_LIB) $(NVIDIA_ATTESTATION_HELPER_LIB_DIR)/libnvat.so.1; \
|
||||
fi
|
||||
|
||||
protoc:
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/agent.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative manager/manager.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/cvms/cvms.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative internal/proto/attestation/v1/attestation.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative internal/proto/attestation-agent/attestation-agent.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/log/log.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/runner/runner.proto
|
||||
|
||||
mocks:
|
||||
mockery --config ./mockery.yml
|
||||
mockery --config ./.mockery.yml
|
||||
|
||||
install: $(SERVICES)
|
||||
install -d $(INSTALL_DIR)
|
||||
|
||||
@@ -77,4 +77,4 @@ Cocos AI is published under the permissive open-source [Apache-2.0](LICENSE) lic
|
||||
- [Confidential Computing Overview](https://confidentialcomputing.io/white-papers-reports/)
|
||||
- [Trusted Execution Environments (TEEs)](https://en.wikipedia.org/wiki/Trusted_execution_environment)
|
||||
|
||||
>This work has been partially supported by the ELASTIC project, which received funding from the Smart Networks and Services Joint Undertaking (SNS JU) under the European Union’s Horizon Europe research and innovation programme under Grant Agreement No 101139067. Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union. Neither the European Union nor the granting authority can be held responsible for them.
|
||||
>This work has been partially supported by the [ELASTIC](https://elasticproject.eu/) and [CONFIDENTIAL6G](https://confidential6g.eu/), which received funding from the Smart Networks and Services Joint Undertaking (SNS JU) under the European Union’s Horizon Europe research and innovation programme under [Grant Agreement No. 101139067](https://cordis.europa.eu/project/id/101139067) and [Grant Agreement No. 101096435](https://cordis.europa.eu/project/id/101096435). Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union. Neither the European Union nor the granting authority can be held responsible for them.
|
||||
|
||||
+41
-10
@@ -6,16 +6,47 @@ Agent service provides a barebones HTTP and gRPC API and Service interface imple
|
||||
|
||||
The service is configured using the environment variables from the following table. Note that any unset variables will be replaced with their default values.
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- | ------------------------------ |
|
||||
| AGENT_LOG_LEVEL | Log level for agent service (debug, info, warn, error) | debug |
|
||||
| AGENT_CVM_GRPC_HOST | Agent service gRPC host | "" |
|
||||
| AGENT_CVM_GRPC_PORT | Agent service gRPC port | 7001 |
|
||||
| AGENT_CVM_GRPC_SERVER_CERT | Path to gRPC server certificate in pem format | "" |
|
||||
| AGENT_CVM_GRPC_SERVER_KEY | Path to gRPC server key in pem format | "" |
|
||||
| AGENT_CVM_GRPC_SERVER_CA_CERTS | Path to gRPC server CA certificate | "" |
|
||||
| AGENT_CVM_GRPC_CLIENT_CA_CERTS | Path to gRPC client CA certificate | "" |
|
||||
| AGENT_CVM_CA_URL | URL for CA service, if provided it will be used for certificate generation, used only with aTLS at the moment | "" |
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- | ----------------------------------------------- |
|
||||
| AGENT_LOG_LEVEL | Log level for agent service (debug, info, warn, error) | debug |
|
||||
| AGENT_VMPL | VMPL (Virtual Machine Privilege Level) for AMD SEV-SNP attestation (0-3) | 2 |
|
||||
| AGENT_GRPC_HOST | Agent service gRPC host address | 0.0.0.0 |
|
||||
| AGENT_CVM_GRPC_HOST | Agent service gRPC host | "" |
|
||||
| AGENT_CVM_GRPC_PORT | Agent service gRPC port | 7001 |
|
||||
| AGENT_CVM_GRPC_SERVER_CERT | Path to gRPC server certificate in pem format | "" |
|
||||
| AGENT_CVM_GRPC_SERVER_KEY | Path to gRPC server key in pem format | "" |
|
||||
| AGENT_CVM_GRPC_SERVER_CA_CERTS | Path to gRPC server CA certificate | "" |
|
||||
| AGENT_CVM_GRPC_CLIENT_CA_CERTS | Path to gRPC client CA certificate | "" |
|
||||
| AGENT_CVM_CA_URL | URL for CA service, if provided it will be used for certificate generation, used only with aTLS at the moment | "" |
|
||||
| AGENT_CVM_ID | Unique identifier for the CVM (Confidential Virtual Machine) | "" |
|
||||
| AGENT_CERTS_TOKEN | Authentication token for certificate service access | "" |
|
||||
| AGENT_MAA_URL | Microsoft Azure Attestation service URL for Azure attestation | https://sharedeus2.eus2.attest.azure.net |
|
||||
| AZURE_TDX_IMDS_URL | Azure TDX quote endpoint used by direct Azure TDX attestation | http://169.254.169.254/acc/tdquote |
|
||||
| AZURE_HCL_REFRESH_WAIT | Wait after writing TDX report data to Azure HCL vTPM storage before reading the refreshed HCL report | 3s |
|
||||
| AGENT_OS_BUILD | Operating system build information for attestation | UVC |
|
||||
| AGENT_OS_DISTRO | Operating system distribution information for attestation | UVC |
|
||||
| AGENT_OS_TYPE | Operating system type information for attestation | UVC |
|
||||
| ATTESTATION_SERVICE_SOCKET | Unix socket path for attestation service communication | /run/cocos/attestation.sock |
|
||||
| AGENT_ENABLE_ATLS | Enable Attestation TLS for secure communication | true |
|
||||
|
||||
### Azure TDX Attestation
|
||||
|
||||
When the agent runs on an Azure TDX CVM, Azure attestation uses the direct Azure TDX flow. The agent writes TDX report data to Azure HCL vTPM storage, reads the refreshed HCL report, requests a TD quote from Azure IMDS, and submits the quote plus HCL runtime data to Microsoft Azure Attestation. This path does not depend on Confidential Containers attestation-agent `GetEvidence` or KBS token retrieval.
|
||||
|
||||
`AGENT_MAA_URL` selects the Microsoft Azure Attestation endpoint. `AZURE_TDX_IMDS_URL` can override the Azure IMDS TDX quote endpoint, and `AZURE_HCL_REFRESH_WAIT` controls the wait used to avoid reading a stale HCL report after report-data is written.
|
||||
|
||||
### Remote Resource Download (Optional)
|
||||
|
||||
The agent supports downloading encrypted algorithms and datasets from remote registries (S3, HTTP/HTTPS) and retrieving decryption keys from a Key Broker Service (KBS) via attestation.
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- | ----------------------------------------------- |
|
||||
| AWS_REGION | AWS region for S3 access (required for S3 downloads) | \"\" |
|
||||
| AWS_ACCESS_KEY_ID | AWS access key ID for S3 authentication | \"\" |
|
||||
| AWS_SECRET_ACCESS_KEY | AWS secret access key for S3 authentication | \"\" |
|
||||
| AWS_ENDPOINT_URL | Custom S3 endpoint URL (for S3-compatible services like MinIO) | \"\" |
|
||||
|
||||
**Note**: KBS URL is specified in the computation manifest, not as an environment variable. See [TESTING_REMOTE_RESOURCES.md](./TESTING_REMOTE_RESOURCES.md) for details on using remote resources.
|
||||
|
||||
## Deployment
|
||||
|
||||
|
||||
@@ -0,0 +1,468 @@
|
||||
# Testing Remote Resources with CoCo Key Provider
|
||||
|
||||
This guide explains how to test Cocos with encrypted remote resources using the Confidential Containers Key Provider ecosystem.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌────────────────────────────────────────────────────────────┐
|
||||
│ CVM (Agent) │
|
||||
│ │
|
||||
│ ┌──────────┐ ┌────────────────┐ ┌─────────────────┐ │
|
||||
│ │ Agent │──▶│ Skopeo │──▶│ CoCo Keyprovider│ │
|
||||
│ │ │ │ (ocicrypt) │ │ (gRPC:50011) │ │
|
||||
│ │ │ └───────┬────────┘ └────────┬────────┘ │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ ┌───────▼────────┐ ┌────────▼────────┐ │
|
||||
│ │ │──▶│ S3/HTTP │ │ Attestation │ │
|
||||
│ │ │ │ Downloader │ │ Agent (50002) │ │
|
||||
│ └────┬─────┘ └───────┬────────┘ └────────┬────────┘ │
|
||||
│ │ │ │ │
|
||||
│ └──────────────────┼──────────────────────┘ │
|
||||
└────────┬─────────────────┼──────────────────────┬──────────┘
|
||||
│ (Resource) │ (Resource) │ (Attest)
|
||||
▼ ▼ ▼
|
||||
OCI Registry S3 / HTTP / GCS KBS
|
||||
(Key Broker)
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Install Skopeo (Host Machine)
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install skopeo
|
||||
|
||||
# macOS
|
||||
brew install skopeo
|
||||
|
||||
# Or build from source
|
||||
git clone https://github.com/containers/skopeo
|
||||
cd skopeo
|
||||
make bin/skopeo
|
||||
sudo make install
|
||||
```
|
||||
|
||||
### 2. Start KBS Server (Host Machine)
|
||||
|
||||
```bash
|
||||
# Clone and build KBS
|
||||
git clone https://github.com/confidential-containers/trustee
|
||||
cd trustee/kbs
|
||||
# Patch Cargo.toml to disable SGX requirement (for testing only)
|
||||
sed -i 's/"all-verifier",//g' Cargo.toml
|
||||
|
||||
make
|
||||
make cli
|
||||
|
||||
# Generate admin keys
|
||||
openssl genpkey -algorithm ed25519 -out kbs-admin.key
|
||||
openssl pkey -in kbs-admin.key -pubout -out kbs-admin.pub
|
||||
|
||||
# Create KBS configuration file
|
||||
cat > kbs-config.toml << 'EOF'
|
||||
[http_server]
|
||||
sockets = ["0.0.0.0:8080"]
|
||||
insecure_http = true
|
||||
|
||||
[admin]
|
||||
type = "Simple"
|
||||
[[admin.personas]]
|
||||
id = "admin"
|
||||
public_key_path = "kbs-admin.pub"
|
||||
|
||||
[attestation_service]
|
||||
type = "coco_as_builtin"
|
||||
work_dir = "kbs-data/as"
|
||||
|
||||
[attestation_service.rvps_config]
|
||||
type = "BuiltIn"
|
||||
|
||||
[attestation_service.rvps_config.storage]
|
||||
type = "LocalFs"
|
||||
file_path = "kbs-data/rvps-values"
|
||||
|
||||
[[plugins]]
|
||||
name = "resource"
|
||||
type = "LocalFs"
|
||||
dir_path = "kbs-data/repository"
|
||||
EOF
|
||||
|
||||
# Create configuration directories
|
||||
mkdir -p kbs-data/as kbs-data/rvps kbs-data/repository
|
||||
|
||||
# Start KBS
|
||||
sudo ../target/release/kbs --config-file kbs-config.toml
|
||||
```
|
||||
|
||||
KBS will listen on `http://localhost:8080`
|
||||
|
||||
### 3. Setup Local OCI Registry (Optional)
|
||||
|
||||
For testing, you can use a local registry:
|
||||
|
||||
```bash
|
||||
docker run -d -p 5000:5000 --name registry registry:2
|
||||
```
|
||||
|
||||
## Creating Encrypted Resources
|
||||
|
||||
### Encrypt an Algorithm (Python Script)
|
||||
|
||||
```bash
|
||||
# 1. Create a simple algorithm
|
||||
cat > lin_reg.py << 'EOF'
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Load dataset
|
||||
data = pd.read_csv(sys.argv[1])
|
||||
X = data[['feature1', 'feature2']]
|
||||
y = data['target']
|
||||
|
||||
# Train model
|
||||
model = LinearRegression()
|
||||
model.fit(X, y)
|
||||
|
||||
# Save results
|
||||
os.makedirs("results", exist_ok=True)
|
||||
with open("results/output.txt", "w") as f:
|
||||
f.write(f"Coefficients: {model.coef_}\n")
|
||||
f.write(f"Intercept: {model.intercept_}\n")
|
||||
|
||||
print(f"Coefficients: {model.coef_}")
|
||||
print(f"Intercept: {model.intercept_}")
|
||||
EOF
|
||||
|
||||
# 2. Create requirements.txt
|
||||
cat > requirements.txt << 'EOF'
|
||||
pandas
|
||||
scikit-learn
|
||||
EOF
|
||||
|
||||
# 3. Create a Dockerfile
|
||||
cat > Dockerfile << 'EOF'
|
||||
FROM python:3.9-slim
|
||||
RUN pip install pandas scikit-learn
|
||||
COPY lin_reg.py /app/algorithm.py
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
WORKDIR /app
|
||||
ENTRYPOINT ["python", "algorithm.py"]
|
||||
EOF
|
||||
|
||||
# 4. Build the image
|
||||
docker build -t localhost:5000/lin-reg-algo:v1.0 .
|
||||
docker push localhost:5000/lin-reg-algo:v1.0
|
||||
|
||||
# 5. Generate and store key
|
||||
openssl rand -out algo.key 32
|
||||
|
||||
# 6. Store key in KBS using kbs-client
|
||||
../target/release/kbs-client --url http://localhost:8080 config \
|
||||
--auth-private-key kbs-admin.key \
|
||||
set-resource \
|
||||
--path default/key/algo-key \
|
||||
--resource-file algo.key
|
||||
|
||||
# 7. Encrypt the image using Host Skopeo + Docker Keyprovider
|
||||
# Start Keyprovider in background
|
||||
docker run -d --rm --name keyprovider --network host \
|
||||
-v "$PWD:/work" -w /work \
|
||||
ghcr.io/confidential-containers/staged-images/coco-keyprovider:latest \
|
||||
coco_keyprovider --socket 127.0.0.1:50000
|
||||
|
||||
# Configure Ocicrypt to use local Keyprovider
|
||||
cat <<EOF > ocicrypt.conf
|
||||
{
|
||||
"key-providers": {
|
||||
"attestation-agent": {
|
||||
"grpc": "127.0.0.1:50000"
|
||||
}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=$(pwd)/ocicrypt.conf
|
||||
|
||||
# Encrypt Algo
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--encryption-key "provider:attestation-agent:keypath=/work/algo.key::keyid=kbs:///default/key/algo-key::algorithm=A256GCM" \
|
||||
docker://localhost:5000/lin-reg-algo:v1.0 \
|
||||
docker://localhost:5000/encrypted-lin-reg:v1.0
|
||||
|
||||
# Stop Keyprovider
|
||||
docker stop keyprovider
|
||||
```
|
||||
|
||||
### Encrypt a Dataset (CSV in OCI Image)
|
||||
|
||||
```bash
|
||||
# 1. Create dataset
|
||||
cat > iris.csv << 'EOF'
|
||||
feature1,feature2,target
|
||||
5.1,3.5,0
|
||||
4.9,3.0,0
|
||||
6.2,3.4,1
|
||||
5.9,3.0,1
|
||||
EOF
|
||||
|
||||
# 2. Create Dockerfile for dataset
|
||||
cat > Dockerfile.dataset << 'EOF'
|
||||
FROM scratch
|
||||
COPY iris.csv /data/iris.csv
|
||||
EOF
|
||||
|
||||
# 3. Build and push
|
||||
docker build -f Dockerfile.dataset -t localhost:5000/iris-dataset:v1.0 .
|
||||
docker push localhost:5000/iris-dataset:v1.0
|
||||
|
||||
# 4. Generate and store key
|
||||
# 4. Generate and store key
|
||||
openssl rand -out dataset.key 32
|
||||
../target/release/kbs-client --url http://localhost:8080 config \
|
||||
--auth-private-key kbs-admin.key \
|
||||
set-resource \
|
||||
--path default/key/dataset-key \
|
||||
--resource-file dataset.key
|
||||
|
||||
# 5. Encrypt dataset image using Host Skopeo + Docker Keyprovider
|
||||
# Start Keyprovider in background
|
||||
docker run -d --rm --name keyprovider --network host \
|
||||
-v "$PWD:/work" -w /work \
|
||||
ghcr.io/confidential-containers/staged-images/coco-keyprovider:latest \
|
||||
coco_keyprovider --socket 127.0.0.1:50000
|
||||
|
||||
# Configure Ocicrypt (if not already done)
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=$(pwd)/ocicrypt.conf
|
||||
|
||||
# Encrypt Dataset
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--encryption-key "provider:attestation-agent:keypath=/work/dataset.key::keyid=kbs:///default/key/dataset-key::algorithm=A256GCM" \
|
||||
docker://localhost:5000/iris-dataset:v1.0 \
|
||||
docker://localhost:5000/encrypted-iris:v1.0
|
||||
|
||||
# Stop Keyprovider
|
||||
docker stop keyprovider
|
||||
```
|
||||
|
||||
## Running a Computation
|
||||
|
||||
### 1. Start Manager (Host)
|
||||
|
||||
```bash
|
||||
cd /path/to/cocos-ai
|
||||
./build/cocos-manager
|
||||
```
|
||||
|
||||
### 2. Start CVMS Test Server (Host)
|
||||
|
||||
Get your host IP:
|
||||
```bash
|
||||
HOST_IP=$(ip -4 addr show | grep -oP '(?<=inet\s)\d+(\.\d+){3}' | grep -v 127.0.0.1 | head -n1)
|
||||
```
|
||||
|
||||
Start CVMS server:
|
||||
```bash
|
||||
# Calculate SHA3-256 of decrypted files using cocos-cli or cvms-test
|
||||
# NOTE: We use the hash of the original plaintext files, as the Agent validates the decrypted content.
|
||||
# For single files, use the file hash. For directories, use the hash of the directory (which the tools zip deterministically).
|
||||
|
||||
ALGO_HASH=$(./build/cocos-cli checksum lin_reg.py 2>&1 | awk '{print $NF}')
|
||||
|
||||
DATASET_HASH=$(./build/cocos-cli checksum iris.csv 2>&1 | awk '{print $NF}')
|
||||
|
||||
go build -o build/cvms-test ./test/cvms/main.go
|
||||
HOST=$HOST_IP PORT=7001 ./build/cvms-test \
|
||||
-public-key-path ./public.pem \
|
||||
-attested-tls-bool false \
|
||||
-algo-type python \
|
||||
-algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \
|
||||
-algo-kbs-path default/key/algo-key \
|
||||
-algo-kbs-url http://$HOST_IP:8080 \
|
||||
-algo-hash $ALGO_HASH \
|
||||
-algo-args datasets/dataset_0.csv \
|
||||
-dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \
|
||||
-dataset-kbs-paths default/key/dataset-key \
|
||||
-dataset-kbs-urls http://$HOST_IP:8080 \
|
||||
-dataset-hash $DATASET_HASH
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> You must specify the KBS URL for each encrypted resource using `-algo-kbs-url` and `-dataset-kbs-urls`. A global KBS is no longer supported.
|
||||
|
||||
|
||||
### 3. Create VM via CLI (Host)
|
||||
|
||||
```bash
|
||||
export MANAGER_GRPC_URL=localhost:7002
|
||||
./build/cocos-cli create-vm \
|
||||
--server-url $HOST_IP:7001 \
|
||||
--log-level debug
|
||||
```
|
||||
|
||||
The agent will:
|
||||
1. Receive computation manifest from CVMS
|
||||
2. Use Skopeo to download encrypted OCI images
|
||||
3. Skopeo invokes CoCo Keyprovider via ocicrypt
|
||||
4. CoCo Keyprovider requests decryption key from KBS
|
||||
5. Attestation Agent generates TEE evidence for KBS
|
||||
6. KBS validates evidence and returns decryption key
|
||||
7. Image layers are decrypted and extracted
|
||||
8. Computation executes with decrypted algorithm and dataset
|
||||
|
||||
## Verifying the Setup
|
||||
|
||||
### Check CoCo Keyprovider Status (Inside CVM)
|
||||
|
||||
```bash
|
||||
# SSH into CVM or use console
|
||||
systemctl status coco-keyprovider
|
||||
journalctl -u coco-keyprovider -f
|
||||
```
|
||||
|
||||
### Check Attestation Agent Status
|
||||
|
||||
```bash
|
||||
systemctl status attestation-agent
|
||||
journalctl -u attestation-agent -f
|
||||
```
|
||||
|
||||
### Test Skopeo Decryption Manually
|
||||
|
||||
```bash
|
||||
# Inside CVM
|
||||
export OCICRYPT_KEYPROVIDER_CONFIG=/etc/ocicrypt_keyprovider.conf
|
||||
|
||||
skopeo copy \
|
||||
--src-tls-verify=false \
|
||||
--dest-tls-verify=false \
|
||||
--decryption-key provider:attestation-agent:cc_kbc::null \
|
||||
docker://localhost:5000/encrypted-lin-reg:v1.0 \
|
||||
oci:/tmp/decrypted-algo
|
||||
|
||||
# Verify decryption
|
||||
skopeo inspect oci:/tmp/decrypted-algo | jq -r '.LayersData[].MIMEType'
|
||||
# Should show: application/vnd.oci.image.layer.v1.tar+gzip
|
||||
```
|
||||
|
||||
## Computation Manifest Format
|
||||
|
||||
The CVMS server sends this manifest to the agent:
|
||||
|
||||
```json
|
||||
{
|
||||
"computation_id": "1",
|
||||
"algorithm": {
|
||||
"type": "oci-image",
|
||||
"uri": "docker://localhost:5000/encrypted-lin-reg:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/algo-key",
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.15:8080",
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"filename": "iris.csv",
|
||||
"source": {
|
||||
"type": "oci-image",
|
||||
"url": "docker://localhost:5000/encrypted-iris:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/dataset-key"
|
||||
},
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.20:8080",
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
],
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.15:8080",
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### CoCo Keyprovider Not Starting
|
||||
|
||||
```bash
|
||||
# Check logs
|
||||
journalctl -u coco-keyprovider -n 50
|
||||
|
||||
# Verify socket is listening
|
||||
ss -tlnp | grep 50011
|
||||
|
||||
# Check environment
|
||||
cat /etc/default/coco-keyprovider
|
||||
```
|
||||
|
||||
### Skopeo Decryption Fails
|
||||
|
||||
```bash
|
||||
# Verify ocicrypt config
|
||||
cat /etc/ocicrypt_keyprovider.conf
|
||||
|
||||
# Test keyprovider connection
|
||||
grpcurl -plaintext 127.0.0.1:50011 list
|
||||
|
||||
# Check KBS connectivity from CVM
|
||||
curl http://HOST_IP:8080/kbs/v0/auth
|
||||
```
|
||||
|
||||
### KBS Returns 401
|
||||
|
||||
```bash
|
||||
# Check KBS logs on host
|
||||
# Verify attestation evidence format
|
||||
# Ensure KBS is configured for sample attestation
|
||||
```
|
||||
|
||||
## 4. Testing with Non-OCI Sources (S3, HTTP, GCS)
|
||||
|
||||
The `cvms` test utility also supports testing remote encrypted resources hosted in more traditional environments like S3-compatible storage or simple web servers, bypassing the need for container registries and OCI images.
|
||||
|
||||
### Supported Flags
|
||||
|
||||
The following flags define how resources should be fetched:
|
||||
|
||||
- `--algo-source-url`: The URL of the algorithm (e.g. `s3://bucket/algo.bin`, `https://server/algo.bin`)
|
||||
- `--algo-source-type`: The type of remote endpoint (`s3`, `gcs`, `https`, `http`). If omitted, it will automatically be inferred from the URL scheme.
|
||||
- `--algo-kbs-path`: The KBS path to retrieve the AES-256-GCM key from. If present, the agent will attempt decryption.
|
||||
- `--dataset-source-urls` and `--dataset-source-type`: Defines the locations and protocols for datasets.
|
||||
|
||||
### Encryption Format for Non-OCI Sources
|
||||
|
||||
Unlike OCI images where `ocicrypt` wraps the dataset, resources hosted on HTTP/S3 must be straightforwardly encrypted using **AES-256-GCM**.
|
||||
|
||||
The expected format is exactly as produced by standard Go AES-GCM:
|
||||
`nonce (12 bytes) || ciphertext || tag`
|
||||
|
||||
### Test Example
|
||||
|
||||
If you had a Python script encrypted using a key hosted at KBS path `default/my-keys/python-script` and uploaded to `s3://my-secure-bucket/script.enc`, you could run:
|
||||
|
||||
```bash
|
||||
cd test
|
||||
go run cvms/main.go --algo-source-url="s3://my-secure-bucket/script.enc" \
|
||||
--algo-source-type="s3" \
|
||||
--algo-kbs-path="default/my-keys/python-script" \
|
||||
--algo-type="python" \
|
||||
--public-key-path=./test-data/public-key.pem
|
||||
```
|
||||
|
||||
The system will:
|
||||
1. Connect via `attestation-agent` to the KBS to retrieve the symmetric key
|
||||
2. Use Google Cloud Storage client library methods (support for generic S3 via environment variables is standard) to fetch the resource
|
||||
3. Decrypt using AES-256-GCM
|
||||
4. Run the code normally
|
||||
|
||||
---
|
||||
+255
-70
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.5
|
||||
// protoc v5.29.0
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.1
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
@@ -384,53 +384,230 @@ func (x *AttestationResponse) GetFile() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
type IMAMeasurementsRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsRequest) Reset() {
|
||||
*x = IMAMeasurementsRequest{}
|
||||
mi := &file_agent_agent_proto_msgTypes[8]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*IMAMeasurementsRequest) ProtoMessage() {}
|
||||
|
||||
func (x *IMAMeasurementsRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[8]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use IMAMeasurementsRequest.ProtoReflect.Descriptor instead.
|
||||
func (*IMAMeasurementsRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_agent_proto_rawDescGZIP(), []int{8}
|
||||
}
|
||||
|
||||
type IMAMeasurementsResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
Pcr10 []byte `protobuf:"bytes,2,opt,name=pcr10,proto3" json:"pcr10,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsResponse) Reset() {
|
||||
*x = IMAMeasurementsResponse{}
|
||||
mi := &file_agent_agent_proto_msgTypes[9]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*IMAMeasurementsResponse) ProtoMessage() {}
|
||||
|
||||
func (x *IMAMeasurementsResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[9]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use IMAMeasurementsResponse.ProtoReflect.Descriptor instead.
|
||||
func (*IMAMeasurementsResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_agent_proto_rawDescGZIP(), []int{9}
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsResponse) GetFile() []byte {
|
||||
if x != nil {
|
||||
return x.File
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *IMAMeasurementsResponse) GetPcr10() []byte {
|
||||
if x != nil {
|
||||
return x.Pcr10
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AttestationTokenRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
TokenNonce []byte `protobuf:"bytes,1,opt,name=tokenNonce,proto3" json:"tokenNonce,omitempty"` // Should be less or equal 32 bytes
|
||||
Type int32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AttestationTokenRequest) Reset() {
|
||||
*x = AttestationTokenRequest{}
|
||||
mi := &file_agent_agent_proto_msgTypes[10]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AttestationTokenRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*AttestationTokenRequest) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[10]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use AttestationTokenRequest.ProtoReflect.Descriptor instead.
|
||||
func (*AttestationTokenRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_agent_proto_rawDescGZIP(), []int{10}
|
||||
}
|
||||
|
||||
func (x *AttestationTokenRequest) GetTokenNonce() []byte {
|
||||
if x != nil {
|
||||
return x.TokenNonce
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *AttestationTokenRequest) GetType() int32 {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type AttestationTokenResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AttestationTokenResponse) Reset() {
|
||||
*x = AttestationTokenResponse{}
|
||||
mi := &file_agent_agent_proto_msgTypes[11]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AttestationTokenResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*AttestationTokenResponse) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[11]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use AttestationTokenResponse.ProtoReflect.Descriptor instead.
|
||||
func (*AttestationTokenResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_agent_proto_rawDescGZIP(), []int{11}
|
||||
}
|
||||
|
||||
func (x *AttestationTokenResponse) GetFile() []byte {
|
||||
if x != nil {
|
||||
return x.File
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_agent_agent_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_agent_agent_proto_rawDesc = string([]byte{
|
||||
0x0a, 0x11, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x0b, 0x41, 0x6c,
|
||||
0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6c, 0x67,
|
||||
0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x61, 0x6c,
|
||||
0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x71, 0x75, 0x69,
|
||||
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x72,
|
||||
0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x22, 0x0e, 0x0a, 0x0c, 0x41,
|
||||
0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x43, 0x0a, 0x0b, 0x44,
|
||||
0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x61,
|
||||
0x74, 0x61, 0x73, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x61, 0x74,
|
||||
0x61, 0x73, 0x65, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65,
|
||||
0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x62, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73,
|
||||
0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a,
|
||||
0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52,
|
||||
0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x76, 0x74, 0x70,
|
||||
0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x76, 0x74,
|
||||
0x70, 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18,
|
||||
0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x41,
|
||||
0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
|
||||
0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, 0x65, 0x6e, 0x74,
|
||||
0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, 0x67, 0x6f, 0x12,
|
||||
0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x33, 0x0a, 0x04,
|
||||
0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74,
|
||||
0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74,
|
||||
0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28,
|
||||
0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, 0x2e, 0x61, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x48, 0x0a, 0x0b,
|
||||
0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41,
|
||||
0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65, 0x6e,
|
||||
0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
})
|
||||
const file_agent_agent_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x11agent/agent.proto\x12\x05agent\"O\n" +
|
||||
"\vAlgoRequest\x12\x1c\n" +
|
||||
"\talgorithm\x18\x01 \x01(\fR\talgorithm\x12\"\n" +
|
||||
"\frequirements\x18\x02 \x01(\fR\frequirements\"\x0e\n" +
|
||||
"\fAlgoResponse\"C\n" +
|
||||
"\vDataRequest\x12\x18\n" +
|
||||
"\adataset\x18\x01 \x01(\fR\adataset\x12\x1a\n" +
|
||||
"\bfilename\x18\x02 \x01(\tR\bfilename\"\x0e\n" +
|
||||
"\fDataResponse\"\x0f\n" +
|
||||
"\rResultRequest\"$\n" +
|
||||
"\x0eResultResponse\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file\"b\n" +
|
||||
"\x12AttestationRequest\x12\x1a\n" +
|
||||
"\bteeNonce\x18\x01 \x01(\fR\bteeNonce\x12\x1c\n" +
|
||||
"\tvtpmNonce\x18\x02 \x01(\fR\tvtpmNonce\x12\x12\n" +
|
||||
"\x04type\x18\x03 \x01(\x05R\x04type\")\n" +
|
||||
"\x13AttestationResponse\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file\"\x18\n" +
|
||||
"\x16IMAMeasurementsRequest\"C\n" +
|
||||
"\x17IMAMeasurementsResponse\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file\x12\x14\n" +
|
||||
"\x05pcr10\x18\x02 \x01(\fR\x05pcr10\"M\n" +
|
||||
"\x17AttestationTokenRequest\x12\x1e\n" +
|
||||
"\n" +
|
||||
"tokenNonce\x18\x01 \x01(\fR\n" +
|
||||
"tokenNonce\x12\x12\n" +
|
||||
"\x04type\x18\x03 \x01(\x05R\x04type\".\n" +
|
||||
"\x18AttestationTokenResponse\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file2\xaf\x03\n" +
|
||||
"\fAgentService\x123\n" +
|
||||
"\x04Algo\x12\x12.agent.AlgoRequest\x1a\x13.agent.AlgoResponse\"\x00(\x01\x123\n" +
|
||||
"\x04Data\x12\x12.agent.DataRequest\x1a\x13.agent.DataResponse\"\x00(\x01\x129\n" +
|
||||
"\x06Result\x12\x14.agent.ResultRequest\x1a\x15.agent.ResultResponse\"\x000\x01\x12H\n" +
|
||||
"\vAttestation\x12\x19.agent.AttestationRequest\x1a\x1a.agent.AttestationResponse\"\x000\x01\x12T\n" +
|
||||
"\x0fIMAMeasurements\x12\x1d.agent.IMAMeasurementsRequest\x1a\x1e.agent.IMAMeasurementsResponse\"\x000\x01\x12Z\n" +
|
||||
"\x15AzureAttestationToken\x12\x1e.agent.AttestationTokenRequest\x1a\x1f.agent.AttestationTokenResponse\"\x00B\tZ\a./agentb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_agent_proto_rawDescOnce sync.Once
|
||||
@@ -444,31 +621,39 @@ func file_agent_agent_proto_rawDescGZIP() []byte {
|
||||
return file_agent_agent_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_agent_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
|
||||
var file_agent_agent_proto_msgTypes = make([]protoimpl.MessageInfo, 12)
|
||||
var file_agent_agent_proto_goTypes = []any{
|
||||
(*AlgoRequest)(nil), // 0: agent.AlgoRequest
|
||||
(*AlgoResponse)(nil), // 1: agent.AlgoResponse
|
||||
(*DataRequest)(nil), // 2: agent.DataRequest
|
||||
(*DataResponse)(nil), // 3: agent.DataResponse
|
||||
(*ResultRequest)(nil), // 4: agent.ResultRequest
|
||||
(*ResultResponse)(nil), // 5: agent.ResultResponse
|
||||
(*AttestationRequest)(nil), // 6: agent.AttestationRequest
|
||||
(*AttestationResponse)(nil), // 7: agent.AttestationResponse
|
||||
(*AlgoRequest)(nil), // 0: agent.AlgoRequest
|
||||
(*AlgoResponse)(nil), // 1: agent.AlgoResponse
|
||||
(*DataRequest)(nil), // 2: agent.DataRequest
|
||||
(*DataResponse)(nil), // 3: agent.DataResponse
|
||||
(*ResultRequest)(nil), // 4: agent.ResultRequest
|
||||
(*ResultResponse)(nil), // 5: agent.ResultResponse
|
||||
(*AttestationRequest)(nil), // 6: agent.AttestationRequest
|
||||
(*AttestationResponse)(nil), // 7: agent.AttestationResponse
|
||||
(*IMAMeasurementsRequest)(nil), // 8: agent.IMAMeasurementsRequest
|
||||
(*IMAMeasurementsResponse)(nil), // 9: agent.IMAMeasurementsResponse
|
||||
(*AttestationTokenRequest)(nil), // 10: agent.AttestationTokenRequest
|
||||
(*AttestationTokenResponse)(nil), // 11: agent.AttestationTokenResponse
|
||||
}
|
||||
var file_agent_agent_proto_depIdxs = []int32{
|
||||
0, // 0: agent.AgentService.Algo:input_type -> agent.AlgoRequest
|
||||
2, // 1: agent.AgentService.Data:input_type -> agent.DataRequest
|
||||
4, // 2: agent.AgentService.Result:input_type -> agent.ResultRequest
|
||||
6, // 3: agent.AgentService.Attestation:input_type -> agent.AttestationRequest
|
||||
1, // 4: agent.AgentService.Algo:output_type -> agent.AlgoResponse
|
||||
3, // 5: agent.AgentService.Data:output_type -> agent.DataResponse
|
||||
5, // 6: agent.AgentService.Result:output_type -> agent.ResultResponse
|
||||
7, // 7: agent.AgentService.Attestation:output_type -> agent.AttestationResponse
|
||||
4, // [4:8] is the sub-list for method output_type
|
||||
0, // [0:4] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
0, // 0: agent.AgentService.Algo:input_type -> agent.AlgoRequest
|
||||
2, // 1: agent.AgentService.Data:input_type -> agent.DataRequest
|
||||
4, // 2: agent.AgentService.Result:input_type -> agent.ResultRequest
|
||||
6, // 3: agent.AgentService.Attestation:input_type -> agent.AttestationRequest
|
||||
8, // 4: agent.AgentService.IMAMeasurements:input_type -> agent.IMAMeasurementsRequest
|
||||
10, // 5: agent.AgentService.AzureAttestationToken:input_type -> agent.AttestationTokenRequest
|
||||
1, // 6: agent.AgentService.Algo:output_type -> agent.AlgoResponse
|
||||
3, // 7: agent.AgentService.Data:output_type -> agent.DataResponse
|
||||
5, // 8: agent.AgentService.Result:output_type -> agent.ResultResponse
|
||||
7, // 9: agent.AgentService.Attestation:output_type -> agent.AttestationResponse
|
||||
9, // 10: agent.AgentService.IMAMeasurements:output_type -> agent.IMAMeasurementsResponse
|
||||
11, // 11: agent.AgentService.AzureAttestationToken:output_type -> agent.AttestationTokenResponse
|
||||
6, // [6:12] is the sub-list for method output_type
|
||||
0, // [0:6] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_agent_proto_init() }
|
||||
@@ -482,7 +667,7 @@ func file_agent_agent_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 8,
|
||||
NumMessages: 12,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -12,6 +12,8 @@ service AgentService {
|
||||
rpc Data(stream DataRequest) returns (DataResponse) {}
|
||||
rpc Result(ResultRequest) returns (stream ResultResponse) {}
|
||||
rpc Attestation(AttestationRequest) returns (stream AttestationResponse) {}
|
||||
rpc IMAMeasurements(IMAMeasurementsRequest) returns (stream IMAMeasurementsResponse) {}
|
||||
rpc AzureAttestationToken(AttestationTokenRequest) returns (AttestationTokenResponse) {}
|
||||
}
|
||||
|
||||
message AlgoRequest {
|
||||
@@ -44,3 +46,19 @@ message AttestationRequest {
|
||||
message AttestationResponse {
|
||||
bytes file = 1;
|
||||
}
|
||||
|
||||
message IMAMeasurementsRequest {
|
||||
}
|
||||
|
||||
message IMAMeasurementsResponse {
|
||||
bytes file = 1;
|
||||
bytes pcr10 = 2;
|
||||
}
|
||||
|
||||
message AttestationTokenRequest{
|
||||
bytes tokenNonce = 1; // Should be less or equal 32 bytes
|
||||
int32 type = 3;
|
||||
}
|
||||
message AttestationTokenResponse{
|
||||
bytes file = 1;
|
||||
}
|
||||
|
||||
+92
-12
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.29.0
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc v6.33.1
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
@@ -22,10 +22,12 @@ import (
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
AgentService_Algo_FullMethodName = "/agent.AgentService/Algo"
|
||||
AgentService_Data_FullMethodName = "/agent.AgentService/Data"
|
||||
AgentService_Result_FullMethodName = "/agent.AgentService/Result"
|
||||
AgentService_Attestation_FullMethodName = "/agent.AgentService/Attestation"
|
||||
AgentService_Algo_FullMethodName = "/agent.AgentService/Algo"
|
||||
AgentService_Data_FullMethodName = "/agent.AgentService/Data"
|
||||
AgentService_Result_FullMethodName = "/agent.AgentService/Result"
|
||||
AgentService_Attestation_FullMethodName = "/agent.AgentService/Attestation"
|
||||
AgentService_IMAMeasurements_FullMethodName = "/agent.AgentService/IMAMeasurements"
|
||||
AgentService_AzureAttestationToken_FullMethodName = "/agent.AgentService/AzureAttestationToken"
|
||||
)
|
||||
|
||||
// AgentServiceClient is the client API for AgentService service.
|
||||
@@ -36,6 +38,8 @@ type AgentServiceClient interface {
|
||||
Data(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[DataRequest, DataResponse], error)
|
||||
Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ResultResponse], error)
|
||||
Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[AttestationResponse], error)
|
||||
IMAMeasurements(ctx context.Context, in *IMAMeasurementsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[IMAMeasurementsResponse], error)
|
||||
AzureAttestationToken(ctx context.Context, in *AttestationTokenRequest, opts ...grpc.CallOption) (*AttestationTokenResponse, error)
|
||||
}
|
||||
|
||||
type agentServiceClient struct {
|
||||
@@ -110,6 +114,35 @@ func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationReq
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type AgentService_AttestationClient = grpc.ServerStreamingClient[AttestationResponse]
|
||||
|
||||
func (c *agentServiceClient) IMAMeasurements(ctx context.Context, in *IMAMeasurementsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[IMAMeasurementsResponse], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &AgentService_ServiceDesc.Streams[4], AgentService_IMAMeasurements_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[IMAMeasurementsRequest, IMAMeasurementsResponse]{ClientStream: stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type AgentService_IMAMeasurementsClient = grpc.ServerStreamingClient[IMAMeasurementsResponse]
|
||||
|
||||
func (c *agentServiceClient) AzureAttestationToken(ctx context.Context, in *AttestationTokenRequest, opts ...grpc.CallOption) (*AttestationTokenResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(AttestationTokenResponse)
|
||||
err := c.cc.Invoke(ctx, AgentService_AzureAttestationToken_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// AgentServiceServer is the server API for AgentService service.
|
||||
// All implementations must embed UnimplementedAgentServiceServer
|
||||
// for forward compatibility.
|
||||
@@ -118,6 +151,8 @@ type AgentServiceServer interface {
|
||||
Data(grpc.ClientStreamingServer[DataRequest, DataResponse]) error
|
||||
Result(*ResultRequest, grpc.ServerStreamingServer[ResultResponse]) error
|
||||
Attestation(*AttestationRequest, grpc.ServerStreamingServer[AttestationResponse]) error
|
||||
IMAMeasurements(*IMAMeasurementsRequest, grpc.ServerStreamingServer[IMAMeasurementsResponse]) error
|
||||
AzureAttestationToken(context.Context, *AttestationTokenRequest) (*AttestationTokenResponse, error)
|
||||
mustEmbedUnimplementedAgentServiceServer()
|
||||
}
|
||||
|
||||
@@ -129,16 +164,22 @@ type AgentServiceServer interface {
|
||||
type UnimplementedAgentServiceServer struct{}
|
||||
|
||||
func (UnimplementedAgentServiceServer) Algo(grpc.ClientStreamingServer[AlgoRequest, AlgoResponse]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Algo not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Algo not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) Data(grpc.ClientStreamingServer[DataRequest, DataResponse]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Data not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Data not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) Result(*ResultRequest, grpc.ServerStreamingServer[ResultResponse]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Result not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Result not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) Attestation(*AttestationRequest, grpc.ServerStreamingServer[AttestationResponse]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Attestation not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Attestation not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) IMAMeasurements(*IMAMeasurementsRequest, grpc.ServerStreamingServer[IMAMeasurementsResponse]) error {
|
||||
return status.Error(codes.Unimplemented, "method IMAMeasurements not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) AzureAttestationToken(context.Context, *AttestationTokenRequest) (*AttestationTokenResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method AzureAttestationToken not implemented")
|
||||
}
|
||||
func (UnimplementedAgentServiceServer) mustEmbedUnimplementedAgentServiceServer() {}
|
||||
func (UnimplementedAgentServiceServer) testEmbeddedByValue() {}
|
||||
@@ -151,7 +192,7 @@ type UnsafeAgentServiceServer interface {
|
||||
}
|
||||
|
||||
func RegisterAgentServiceServer(s grpc.ServiceRegistrar, srv AgentServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedAgentServiceServer was
|
||||
// If the following call panics, it indicates UnimplementedAgentServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
@@ -197,13 +238,47 @@ func _AgentService_Attestation_Handler(srv interface{}, stream grpc.ServerStream
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type AgentService_AttestationServer = grpc.ServerStreamingServer[AttestationResponse]
|
||||
|
||||
func _AgentService_IMAMeasurements_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(IMAMeasurementsRequest)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(AgentServiceServer).IMAMeasurements(m, &grpc.GenericServerStream[IMAMeasurementsRequest, IMAMeasurementsResponse]{ServerStream: stream})
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type AgentService_IMAMeasurementsServer = grpc.ServerStreamingServer[IMAMeasurementsResponse]
|
||||
|
||||
func _AgentService_AzureAttestationToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AttestationTokenRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(AgentServiceServer).AzureAttestationToken(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: AgentService_AzureAttestationToken_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(AgentServiceServer).AzureAttestationToken(ctx, req.(*AttestationTokenRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// AgentService_ServiceDesc is the grpc.ServiceDesc for AgentService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var AgentService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "agent.AgentService",
|
||||
HandlerType: (*AgentServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "AzureAttestationToken",
|
||||
Handler: _AgentService_AzureAttestationToken_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Algo",
|
||||
@@ -225,6 +300,11 @@ var AgentService_ServiceDesc = grpc.ServiceDesc{
|
||||
Handler: _AgentService_Attestation_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
{
|
||||
StreamName: "IMAMeasurements",
|
||||
Handler: _AgentService_IMAMeasurements_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "agent/agent.proto",
|
||||
}
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
package binary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
var _ algorithm.Algorithm = (*binary)(nil)
|
||||
|
||||
type binary struct {
|
||||
@@ -21,6 +26,7 @@ type binary struct {
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
@@ -33,13 +39,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
|
||||
}
|
||||
|
||||
func (b *binary) Run() error {
|
||||
b.cmd = exec.Command(b.algoFile, b.args...)
|
||||
b.mu.Lock()
|
||||
b.cmd = execCommand(b.algoFile, b.args...)
|
||||
b.cmd.Stderr = b.stderr
|
||||
b.cmd.Stdout = b.stdout
|
||||
|
||||
if err := b.cmd.Start(); err != nil {
|
||||
b.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if err := b.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
@@ -49,15 +58,18 @@ func (b *binary) Run() error {
|
||||
}
|
||||
|
||||
func (b *binary) Stop() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.cmd.ProcessState != nil && b.cmd.ProcessState.Exited() {
|
||||
if b.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := b.cmd.Process.Kill(); err != nil {
|
||||
if err := b.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,14 @@ package binary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
@@ -73,6 +77,7 @@ func TestBinaryRun(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args, "").(*binary)
|
||||
|
||||
@@ -98,3 +103,68 @@ func TestBinaryRun(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
b := &binary{}
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
b := &binary{
|
||||
algoFile: "sleep",
|
||||
args: []string{"10"},
|
||||
}
|
||||
if err := b.Run(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify it actually stopped
|
||||
_ = b.cmd.Wait()
|
||||
})
|
||||
|
||||
t.Run("stop already exited", func(t *testing.T) {
|
||||
b := &binary{
|
||||
algoFile: "echo",
|
||||
args: []string{"test"},
|
||||
stdout: io.Discard,
|
||||
stderr: io.Discard,
|
||||
}
|
||||
if err := b.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := b.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunError(t *testing.T) {
|
||||
// Mock execCommand to return an error on Start
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommandError
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
b := NewAlgorithm(logger, eventsSvc, "test", nil, "").(*binary)
|
||||
|
||||
err := b.Run()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func mockExecCommandError(command string, args ...string) *exec.Cmd {
|
||||
// This will make Start() fail if we use a non-existent binary
|
||||
return exec.Command("non_existent_binary_for_sure_12345")
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ type docker struct {
|
||||
logger *slog.Logger
|
||||
stderr io.Writer
|
||||
stdout io.Writer
|
||||
cmpID string
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile, cmpID string) algorithm.Algorithm {
|
||||
@@ -41,6 +42,7 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile, cmpID
|
||||
logger: logger,
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
cmpID: cmpID,
|
||||
}
|
||||
|
||||
return d
|
||||
@@ -107,7 +109,7 @@ func (d *docker) Run() error {
|
||||
Target: resultsMountPath,
|
||||
},
|
||||
},
|
||||
}, nil, nil, containerName)
|
||||
}, nil, nil, fmt.Sprintf("%s-%s", containerName, d.cmpID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create a Docker container: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewAlgorithm creates a new instance of Algorithm. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAlgorithm(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Algorithm {
|
||||
mock := &Algorithm{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Algorithm is an autogenerated mock type for the Algorithm type
|
||||
type Algorithm struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Algorithm_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Algorithm) EXPECT() *Algorithm_Expecter {
|
||||
return &Algorithm_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Run provides a mock function for the type Algorithm
|
||||
func (_mock *Algorithm) Run() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Algorithm_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type Algorithm_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
func (_e *Algorithm_Expecter) Run() *Algorithm_Run_Call {
|
||||
return &Algorithm_Run_Call{Call: _e.mock.On("Run")}
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) Run(run func()) *Algorithm_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) Return(err error) *Algorithm_Run_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) RunAndReturn(run func() error) *Algorithm_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function for the type Algorithm
|
||||
func (_mock *Algorithm) Stop() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Algorithm_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type Algorithm_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *Algorithm_Expecter) Stop() *Algorithm_Stop_Call {
|
||||
return &Algorithm_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) Run(run func()) *Algorithm_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) Return(err error) *Algorithm_Stop_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) RunAndReturn(run func() error) *Algorithm_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -4,12 +4,14 @@ package python
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
@@ -40,6 +42,7 @@ type python struct {
|
||||
requirementsFile string
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
@@ -60,6 +63,12 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requir
|
||||
|
||||
func (p *python) Run() error {
|
||||
venvPath := "venv"
|
||||
defer func() {
|
||||
if err := os.RemoveAll(venvPath); err != nil {
|
||||
_, _ = p.stderr.Write([]byte(fmt.Sprintf("error removing virtual environment: %v\n", err)))
|
||||
}
|
||||
}()
|
||||
|
||||
createVenvCmd := exec.Command(p.runtime, "-m", "venv", venvPath)
|
||||
createVenvCmd.Stderr = p.stderr
|
||||
createVenvCmd.Stdout = p.stdout
|
||||
@@ -69,11 +78,11 @@ func (p *python) Run() error {
|
||||
|
||||
pythonPath := filepath.Join(venvPath, "bin", "python")
|
||||
|
||||
updatePipCmd := exec.Command(pythonPath, "-m", "pip", "install", "--upgrade", "pip")
|
||||
updatePipCmd := exec.Command(pythonPath, "-m", "pip", "install", "--upgrade", "pip", "setuptools", "wheel")
|
||||
updatePipCmd.Stderr = p.stderr
|
||||
updatePipCmd.Stdout = p.stdout
|
||||
if err := updatePipCmd.Run(); err != nil {
|
||||
return fmt.Errorf("error updating pip: %v", err)
|
||||
return fmt.Errorf("error updating pip, setuptools and wheel: %v", err)
|
||||
}
|
||||
|
||||
if p.requirementsFile != "" {
|
||||
@@ -86,35 +95,37 @@ func (p *python) Run() error {
|
||||
}
|
||||
|
||||
args := append([]string{p.algoFile}, p.args...)
|
||||
p.mu.Lock()
|
||||
p.cmd = exec.Command(pythonPath, args...)
|
||||
p.cmd.Stderr = p.stderr
|
||||
p.cmd.Stdout = p.stdout
|
||||
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if err := p.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(venvPath); err != nil {
|
||||
return fmt.Errorf("error removing virtual environment: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *python) Stop() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.cmd.ProcessState != nil && p.cmd.ProcessState.Exited() {
|
||||
if p.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.cmd.Process.Kill(); err != nil {
|
||||
if err := p.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,10 +8,14 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -85,6 +89,7 @@ func TestRun(t *testing.T) {
|
||||
}
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
@@ -126,6 +131,7 @@ func TestRunWithRequirements(t *testing.T) {
|
||||
}
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
@@ -146,3 +152,91 @@ func TestRunWithRequirements(t *testing.T) {
|
||||
t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
p := &python{}
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
p := &python{
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
|
||||
p.cmd = exec.Command("python3", "-c", "import time; time.sleep(10)")
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
|
||||
// Verify it actually stopped
|
||||
_ = p.cmd.Wait()
|
||||
})
|
||||
|
||||
t.Run("stop already exited", func(t *testing.T) {
|
||||
p := &python{}
|
||||
p.cmd = exec.Command("python3", "-c", "print(1)")
|
||||
if err := p.cmd.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := p.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRun_Errors(t *testing.T) {
|
||||
t.Run("invalid runtime error", func(t *testing.T) {
|
||||
algo := &python{
|
||||
algoFile: "algo.py",
|
||||
runtime: "non-existent-python",
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
err := algo.Run()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error creating virtual environment")
|
||||
})
|
||||
|
||||
t.Run("pip install failure", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "python-err-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
scriptPath := filepath.Join(tmpDir, "test.py")
|
||||
require.NoError(t, os.WriteFile(scriptPath, []byte("print(1)"), 0o644))
|
||||
|
||||
reqPath := filepath.Join(tmpDir, "requirements.txt")
|
||||
require.NoError(t, os.WriteFile(reqPath, []byte("non-existent-package==9.9.9"), 0o644))
|
||||
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
requirementsFile: reqPath,
|
||||
runtime: "python3",
|
||||
stderr: io.Discard,
|
||||
stdout: io.Discard,
|
||||
}
|
||||
err = algo.Run()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error installing requirements")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAlgorithmEmptyRuntime(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
algo := NewAlgorithm(slog.Default(), eventsSvc, "", "req.txt", "algo.py", nil, "")
|
||||
p := algo.(*python)
|
||||
if p.runtime != PyRuntime {
|
||||
t.Errorf("Expected default runtime %s, got %s", PyRuntime, p.runtime)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
const wasmRuntime = "wasmedge"
|
||||
|
||||
var mapDirOption = []string{"--dir", ".:" + algorithm.ResultsDir}
|
||||
@@ -25,6 +30,7 @@ type wasm struct {
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string, algoFile, cmpID string) algorithm.Algorithm {
|
||||
@@ -39,13 +45,16 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string,
|
||||
func (w *wasm) Run() error {
|
||||
args := append(mapDirOption, w.algoFile)
|
||||
args = append(args, w.args...)
|
||||
w.cmd = exec.Command(wasmRuntime, args...)
|
||||
w.mu.Lock()
|
||||
w.cmd = execCommand(wasmRuntime, args...)
|
||||
w.cmd.Stderr = w.stderr
|
||||
w.cmd.Stdout = w.stdout
|
||||
|
||||
if err := w.cmd.Start(); err != nil {
|
||||
w.mu.Unlock()
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
if err := w.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
@@ -55,15 +64,18 @@ func (w *wasm) Run() error {
|
||||
}
|
||||
|
||||
func (w *wasm) Stop() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.cmd.ProcessState != nil && w.cmd.ProcessState.Exited() {
|
||||
if w.cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.cmd.Process.Kill(); err != nil {
|
||||
if err := w.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,15 +7,18 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
const testWasm = "test.wasm"
|
||||
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, args, algoFile, "")
|
||||
@@ -49,14 +52,18 @@ func TestRunError(t *testing.T) {
|
||||
execCommand = mockExecCommandError
|
||||
defer func() { execCommand = exec.Command }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := NewAlgorithm(logger, eventsSvc, args, algoFile, "").(*wasm)
|
||||
w := &wasm{
|
||||
algoFile: algoFile,
|
||||
args: args,
|
||||
stderr: os.Stderr, // Use real stderr or io.Discard
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
err := w.Run()
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Run() should have returned an error")
|
||||
}
|
||||
@@ -76,14 +83,97 @@ func mockExecCommandError(command string, args ...string) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
t.Run("stop nil cmd", func(t *testing.T) {
|
||||
w := &wasm{}
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop with running process", func(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
w := &wasm{
|
||||
algoFile: testWasm,
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
}
|
||||
|
||||
// We need to simulate a running process.
|
||||
// mockExecCommand returns a command that runs TestHelperProcess.
|
||||
// If we don't call Wait(), it keeps running? No, TestHelperProcess exits immediately.
|
||||
// Let's modify TestHelperProcess to sleep if an env var is set.
|
||||
|
||||
w.cmd = mockExecCommand("sleep", "10")
|
||||
w.cmd.Env = append(w.cmd.Env, "GO_WANT_HELPER_PROCESS_SLEEP=1")
|
||||
if err := w.cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
_ = w.cmd.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopAlreadyExited(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
w := &wasm{
|
||||
algoFile: testWasm,
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
}
|
||||
|
||||
w.cmd = mockExecCommand("true")
|
||||
if err := w.cmd.Run(); err != nil {
|
||||
t.Fatalf("Failed to run command: %v", err)
|
||||
}
|
||||
|
||||
err := w.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSuccess(t *testing.T) {
|
||||
oldExecCommand := execCommand
|
||||
execCommand = mockExecCommand
|
||||
defer func() { execCommand = oldExecCommand }()
|
||||
|
||||
algoFile := testWasm
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := &wasm{
|
||||
algoFile: algoFile,
|
||||
args: args,
|
||||
stderr: os.Stderr,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
err := w.Run()
|
||||
if err != nil {
|
||||
t.Errorf("Run() returned unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS_SLEEP") == "1" {
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" {
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
var execCommand = exec.Command
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
func algoEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(algoReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
@@ -30,7 +30,7 @@ func algoEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
func dataEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(dataReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
@@ -49,7 +49,7 @@ func dataEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
func resultEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(resultReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
@@ -65,13 +65,13 @@ func resultEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
func attestationEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(attestationReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
return attestationRes{}, err
|
||||
}
|
||||
file, err := svc.Attestation(ctx, req.TeeNonce, req.VtpmNonce, config.AttestationType(req.AttType))
|
||||
file, err := svc.Attestation(ctx, req.TeeNonce, req.VtpmNonce, attestation.PlatformType(req.AttType))
|
||||
if err != nil {
|
||||
return attestationRes{}, err
|
||||
}
|
||||
@@ -79,3 +79,33 @@ func attestationEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return attestationRes{File: file}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func imaMeasurementsEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(imaMeasurementsReq)
|
||||
|
||||
if err := req.validate(); err != nil {
|
||||
return imaMeasurementsRes{}, err
|
||||
}
|
||||
file, pcr10, err := svc.IMAMeasurements(ctx)
|
||||
if err != nil {
|
||||
return imaMeasurementsRes{}, err
|
||||
}
|
||||
|
||||
return imaMeasurementsRes{File: file, PCR10: pcr10}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func azureAttestationTokenEndpoint(svc agent.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(azureAttestationTokenReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return fetchAttestationTokenRes{}, err
|
||||
}
|
||||
file, err := svc.AzureAttestationToken(ctx, req.tokenNonce)
|
||||
if err != nil {
|
||||
return fetchAttestationTokenRes{}, err
|
||||
}
|
||||
return fetchAttestationTokenRes{File: file}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
@@ -142,11 +143,11 @@ func TestAttestationEndpoint(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP},
|
||||
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: attestation.SNP},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: config.SNP},
|
||||
req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: attestation.SNP},
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
@@ -172,3 +173,55 @@ func TestAttestationEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationTokenEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
tests := []struct {
|
||||
name string
|
||||
req azureAttestationTokenReq
|
||||
mockErr error
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: azureAttestationTokenReq{tokenNonce: sha3.Sum256([]byte("vtpm nonce"))},
|
||||
mockErr: nil,
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: azureAttestationTokenReq{tokenNonce: sha3.Sum256([]byte("vtpm nonce"))},
|
||||
mockErr: errors.New("mock failure"),
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Only call service mock if validation is expected to pass
|
||||
if err := tt.req.validate(); err == nil {
|
||||
svc.On("AzureAttestationToken", mock.Anything, tt.req.tokenNonce).
|
||||
Return([]byte("mock file"), tt.mockErr).Once()
|
||||
}
|
||||
|
||||
endpoint := azureAttestationTokenEndpoint(svc)
|
||||
res, err := endpoint(context.Background(), tt.req)
|
||||
|
||||
if (err != nil) != tt.expectedErr {
|
||||
t.Errorf("attestationTokenEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
|
||||
}
|
||||
|
||||
if !tt.expectedErr {
|
||||
r, ok := res.(fetchAttestationTokenRes)
|
||||
if !ok {
|
||||
t.Errorf("attestationTokenEndpoint() returned unexpected type %T", res)
|
||||
}
|
||||
if string(r.File) != "mock file" {
|
||||
t.Errorf("expected file content 'mock file', got %s", r.File)
|
||||
}
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func NewAuthInterceptor(authSvc auth.Authenticator) (grpc.UnaryServerInterceptor
|
||||
}
|
||||
|
||||
func (s *authInterceptor) AuthStreamInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
switch info.FullMethod {
|
||||
case agent.AgentService_Algo_FullMethodName:
|
||||
if _, err := s.auth.AuthenticateUser(stream.Context(), auth.AlgorithmProviderRole); err != nil {
|
||||
@@ -59,7 +59,7 @@ func (s *authInterceptor) AuthStreamInterceptor() grpc.StreamServerInterceptor {
|
||||
}
|
||||
|
||||
func (s *authInterceptor) AuthUnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
switch info.FullMethod {
|
||||
case agent.AgentService_Result_FullMethodName:
|
||||
ctx, err := s.auth.AuthenticateUser(ctx, auth.ConsumerRole)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestAuthUnaryInterceptor(t *testing.T) {
|
||||
}
|
||||
unaryInt, _ := NewAuthInterceptor(authmock)
|
||||
|
||||
_, err := unaryInt(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: tt.method}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
_, err := unaryInt(context.Background(), nil, &grpc.UnaryServerInfo{FullMethod: tt.method}, func(ctx context.Context, req any) (any, error) {
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
@@ -129,7 +129,7 @@ func TestAuthStreamInterceptor(t *testing.T) {
|
||||
}
|
||||
_, streamInt := NewAuthInterceptor(authmock)
|
||||
|
||||
err := streamInt(nil, &mockServerStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs())}, &grpc.StreamServerInfo{FullMethod: tt.method}, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
err := streamInt(nil, &mockServerStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs())}, &grpc.StreamServerInfo{FullMethod: tt.method}, func(srv any, stream grpc.ServerStream) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ package grpc
|
||||
import (
|
||||
"errors"
|
||||
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type dataReq struct {
|
||||
|
||||
func (req dataReq) validate() error {
|
||||
if len(req.Dataset) == 0 {
|
||||
return errors.New("dataset CSV file is required")
|
||||
return errors.New("dataset is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -42,16 +41,35 @@ func (req resultReq) validate() error {
|
||||
}
|
||||
|
||||
type attestationReq struct {
|
||||
TeeNonce [quoteprovider.Nonce]byte
|
||||
TeeNonce [vtpm.SEVNonce]byte
|
||||
VtpmNonce [vtpm.Nonce]byte
|
||||
AttType config.AttestationType
|
||||
AttType attestation.PlatformType
|
||||
}
|
||||
|
||||
type azureAttestationTokenReq struct {
|
||||
tokenNonce [vtpm.Nonce]byte
|
||||
}
|
||||
|
||||
func (req attestationReq) validate() error {
|
||||
switch req.AttType {
|
||||
case config.SNP, config.VTPM, config.SNPvTPM:
|
||||
return validateAttestationType(req.AttType)
|
||||
}
|
||||
|
||||
func (req azureAttestationTokenReq) validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateAttestationType(attType attestation.PlatformType) error {
|
||||
switch attType {
|
||||
case attestation.SNP, attestation.VTPM, attestation.SNPvTPM, attestation.Azure, attestation.TDX:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid attestation type in attestation request")
|
||||
return errors.New("invalid attestation type")
|
||||
}
|
||||
}
|
||||
|
||||
type imaMeasurementsReq struct{}
|
||||
|
||||
func (req imaMeasurementsReq) validate() error {
|
||||
// No request parameters to validate, so no validation logic needed
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,9 +7,18 @@ type algoRes struct{}
|
||||
type dataRes struct{}
|
||||
|
||||
type resultRes struct {
|
||||
File []byte `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
File []byte
|
||||
}
|
||||
|
||||
type attestationRes struct {
|
||||
File []byte
|
||||
}
|
||||
|
||||
type imaMeasurementsRes struct {
|
||||
File []byte
|
||||
PCR10 []byte
|
||||
}
|
||||
|
||||
type fetchAttestationTokenRes struct {
|
||||
File []byte
|
||||
}
|
||||
|
||||
+317
-115
@@ -8,11 +8,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
"github.com/go-kit/kit/transport/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -25,213 +26,414 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTEENonceLength = errors.New("malformed report data, expect less or equal to 64 bytes")
|
||||
ErrVTpmNonceLength = errors.New("malformed vTPM nonce, expect less or equal to 32 bytes")
|
||||
ErrTEENonceLength = errors.New("malformed report data, expect less or equal to 64 bytes")
|
||||
ErrVTPMNonceLength = errors.New("malformed vTPM nonce, expect less or equal to 32 bytes")
|
||||
ErrTokenNonceLength = errors.New("malformed token nonce, expect less or equal to 32 bytes")
|
||||
)
|
||||
|
||||
var _ agent.AgentServiceServer = (*grpcServer)(nil)
|
||||
|
||||
type grpcServer struct {
|
||||
algo grpc.Handler
|
||||
data grpc.Handler
|
||||
result grpc.Handler
|
||||
attestation grpc.Handler
|
||||
handlers map[string]grpc.Handler
|
||||
agent.UnimplementedAgentServiceServer
|
||||
}
|
||||
|
||||
type endpointConfig struct {
|
||||
endpoint func(agent.Service) endpoint.Endpoint
|
||||
decodeRequest grpc.DecodeRequestFunc
|
||||
encodeResponse grpc.EncodeResponseFunc
|
||||
}
|
||||
|
||||
// NewServer returns new AgentServiceServer instance.
|
||||
func NewServer(svc agent.Service) agent.AgentServiceServer {
|
||||
// Define endpoint configurations
|
||||
endpoints := map[string]endpointConfig{
|
||||
"algo": {
|
||||
endpoint: algoEndpoint,
|
||||
decodeRequest: decodeAlgoRequest,
|
||||
encodeResponse: encodeAlgoResponse,
|
||||
},
|
||||
"data": {
|
||||
endpoint: dataEndpoint,
|
||||
decodeRequest: decodeDataRequest,
|
||||
encodeResponse: encodeDataResponse,
|
||||
},
|
||||
"result": {
|
||||
endpoint: resultEndpoint,
|
||||
decodeRequest: decodeResultRequest,
|
||||
encodeResponse: encodeResultResponse,
|
||||
},
|
||||
"attestation": {
|
||||
endpoint: attestationEndpoint,
|
||||
decodeRequest: decodeAttestationRequest,
|
||||
encodeResponse: encodeAttestationResponse,
|
||||
},
|
||||
"imaMeasurements": {
|
||||
endpoint: imaMeasurementsEndpoint,
|
||||
decodeRequest: decodeIMAMeasurementsRequest,
|
||||
encodeResponse: encodeIMAMeasurementsResponse,
|
||||
},
|
||||
"azureAttestationToken": {
|
||||
endpoint: azureAttestationTokenEndpoint,
|
||||
decodeRequest: decodeAttestationTokenRequest,
|
||||
encodeResponse: encodeAttestationTokenResponse,
|
||||
},
|
||||
}
|
||||
|
||||
// Create handlers using the configurations
|
||||
handlers := make(map[string]grpc.Handler)
|
||||
for name, config := range endpoints {
|
||||
handlers[name] = grpc.NewServer(
|
||||
config.endpoint(svc),
|
||||
config.decodeRequest,
|
||||
config.encodeResponse,
|
||||
)
|
||||
}
|
||||
|
||||
return &grpcServer{
|
||||
algo: grpc.NewServer(
|
||||
algoEndpoint(svc),
|
||||
decodeAlgoRequest,
|
||||
encodeAlgoResponse,
|
||||
),
|
||||
data: grpc.NewServer(
|
||||
dataEndpoint(svc),
|
||||
decodeDataRequest,
|
||||
encodeDataResponse,
|
||||
),
|
||||
result: grpc.NewServer(
|
||||
resultEndpoint(svc),
|
||||
decodeResultRequest,
|
||||
encodeResultResponse,
|
||||
),
|
||||
attestation: grpc.NewServer(
|
||||
attestationEndpoint(svc),
|
||||
decodeAttestationRequest,
|
||||
encodeAttestationResponse,
|
||||
),
|
||||
handlers: handlers,
|
||||
}
|
||||
}
|
||||
|
||||
func decodeAlgoRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
func decodeAlgoRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*agent.AlgoRequest)
|
||||
|
||||
return algoReq{
|
||||
Algorithm: req.Algorithm,
|
||||
Requirements: req.Requirements,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAlgoResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
func encodeAlgoResponse(_ context.Context, response any) (any, error) {
|
||||
return &agent.AlgoResponse{}, nil
|
||||
}
|
||||
|
||||
func decodeDataRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
func decodeDataRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*agent.DataRequest)
|
||||
|
||||
return dataReq{
|
||||
Dataset: req.Dataset,
|
||||
Filename: req.Filename,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeDataResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
func encodeDataResponse(_ context.Context, response any) (any, error) {
|
||||
return &agent.DataResponse{}, nil
|
||||
}
|
||||
|
||||
func decodeResultRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
func decodeResultRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return resultReq{}, nil
|
||||
}
|
||||
|
||||
func encodeResultResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
func encodeResultResponse(_ context.Context, response any) (any, error) {
|
||||
res := response.(resultRes)
|
||||
return &agent.ResultResponse{
|
||||
File: res.File,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*agent.AttestationRequest)
|
||||
var reportData [quoteprovider.Nonce]byte
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if len(req.TeeNonce) > quoteprovider.Nonce {
|
||||
return nil, ErrTEENonceLength
|
||||
func validateNonce(nonce []byte, maxLen int, target any) error {
|
||||
if len(nonce) > maxLen {
|
||||
switch maxLen {
|
||||
case vtpm.SEVNonce:
|
||||
return ErrTEENonceLength
|
||||
case vtpm.Nonce:
|
||||
return ErrVTPMNonceLength
|
||||
default:
|
||||
return ErrTokenNonceLength
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.VtpmNonce) > vtpm.Nonce {
|
||||
return nil, ErrVTpmNonceLength
|
||||
switch t := target.(type) {
|
||||
case *[vtpm.SEVNonce]byte:
|
||||
copy(t[:], nonce)
|
||||
case *[vtpm.Nonce]byte:
|
||||
copy(t[:], nonce)
|
||||
default:
|
||||
return fmt.Errorf("unsupported target type for nonce validation: %T", target)
|
||||
}
|
||||
|
||||
copy(reportData[:], req.TeeNonce)
|
||||
copy(nonce[:], req.VtpmNonce)
|
||||
return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: config.AttestationType(req.Type)}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) {
|
||||
func decodeAttestationRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*agent.AttestationRequest)
|
||||
var reportData [vtpm.SEVNonce]byte
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if err := validateNonce(req.TeeNonce, vtpm.SEVNonce, &reportData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateNonce(req.VtpmNonce, vtpm.Nonce, &nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return attestationReq{
|
||||
TeeNonce: reportData,
|
||||
VtpmNonce: nonce,
|
||||
AttType: attestation.PlatformType(req.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAttestationResponse(_ context.Context, response any) (any, error) {
|
||||
res := response.(attestationRes)
|
||||
return &agent.AttestationResponse{
|
||||
File: res.File,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Algo implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
|
||||
var algoFile, reqFile []byte
|
||||
func decodeAttestationTokenRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*agent.AttestationTokenRequest)
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if err := validateNonce(req.TokenNonce, vtpm.Nonce, &nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return azureAttestationTokenReq{
|
||||
tokenNonce: nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAttestationTokenResponse(_ context.Context, response any) (any, error) {
|
||||
res := response.(fetchAttestationTokenRes)
|
||||
return &agent.AttestationTokenResponse{
|
||||
File: res.File,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeIMAMeasurementsRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return imaMeasurementsReq{}, nil
|
||||
}
|
||||
|
||||
func encodeIMAMeasurementsResponse(_ context.Context, response any) (any, error) {
|
||||
res := response.(imaMeasurementsRes)
|
||||
return &agent.IMAMeasurementsResponse{
|
||||
File: res.File,
|
||||
Pcr10: res.PCR10,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) streamingHandler(
|
||||
ctx context.Context,
|
||||
handlerName string,
|
||||
req any,
|
||||
stream any,
|
||||
sendFn func([]byte) error,
|
||||
getFileData func(any) []byte,
|
||||
) error {
|
||||
handler, ok := s.handlers[handlerName]
|
||||
if !ok {
|
||||
return status.Errorf(codes.NotFound, "handler %q not found", handlerName)
|
||||
}
|
||||
|
||||
_, res, err := handler.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileData := getFileData(res)
|
||||
|
||||
// Set file size header
|
||||
if setter, ok := stream.(interface{ SetHeader(metadata.MD) error }); ok {
|
||||
if err := setter.SetHeader(metadata.New(map[string]string{
|
||||
FileSizeKey: fmt.Sprint(len(fileData)),
|
||||
})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the file data
|
||||
return s.streamFileData(bytes.NewBuffer(fileData), sendFn)
|
||||
}
|
||||
|
||||
func (s *grpcServer) streamFileData(buffer *bytes.Buffer, sendFn func([]byte) error) error {
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
algoChunk, err := stream.Recv()
|
||||
n, err := buffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
algoFile = append(algoFile, algoChunk.Algorithm...)
|
||||
reqFile = append(reqFile, algoChunk.Requirements...)
|
||||
|
||||
if err := sendFn(buf[:n]); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
_, res, err := s.algo.ServeGRPC(stream.Context(), &agent.AlgoRequest{Algorithm: algoFile, Requirements: reqFile})
|
||||
return nil
|
||||
}
|
||||
|
||||
func receiveStreamingData(getData func() ([]byte, string, error)) ([]byte, string, error) {
|
||||
var data []byte
|
||||
var filename string
|
||||
|
||||
for {
|
||||
chunk, fname, err := getData()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
data = append(data, chunk...)
|
||||
if fname != "" {
|
||||
filename = fname
|
||||
}
|
||||
}
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
// Algo implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error {
|
||||
algoFile, reqFile, err := s.receiveAlgoData(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ar := res.(*agent.AlgoResponse)
|
||||
return stream.SendAndClose(ar)
|
||||
|
||||
_, res, err := s.handlers["algo"].ServeGRPC(stream.Context(), &agent.AlgoRequest{
|
||||
Algorithm: algoFile,
|
||||
Requirements: reqFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.SendAndClose(res.(*agent.AlgoResponse))
|
||||
}
|
||||
|
||||
func (s *grpcServer) receiveAlgoData(stream agent.AgentService_AlgoServer) ([]byte, []byte, error) {
|
||||
var algoFile, reqFile []byte
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
algoFile = append(algoFile, chunk.Algorithm...)
|
||||
reqFile = append(reqFile, chunk.Requirements...)
|
||||
}
|
||||
return algoFile, reqFile, nil
|
||||
}
|
||||
|
||||
// Data implements agent.AgentServiceServer.
|
||||
func (s *grpcServer) Data(stream agent.AgentService_DataServer) error {
|
||||
var dataFile []byte
|
||||
var filename string
|
||||
for {
|
||||
dataChunk, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
dataFile, filename, err := receiveStreamingData(func() ([]byte, string, error) {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
return nil, "", err
|
||||
}
|
||||
dataFile = append(dataFile, dataChunk.Dataset...)
|
||||
filename = dataChunk.Filename
|
||||
}
|
||||
_, res, err := s.data.ServeGRPC(stream.Context(), &agent.DataRequest{Dataset: dataFile, Filename: filename})
|
||||
return chunk.Dataset, chunk.Filename, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ar := res.(*agent.DataResponse)
|
||||
return stream.SendAndClose(ar)
|
||||
|
||||
_, res, err := s.handlers["data"].ServeGRPC(stream.Context(), &agent.DataRequest{
|
||||
Dataset: dataFile,
|
||||
Filename: filename,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.SendAndClose(res.(*agent.DataResponse))
|
||||
}
|
||||
|
||||
func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error {
|
||||
_, res, err := s.result.ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.ResultResponse)
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
resultBuffer := bytes.NewBuffer(rr.File)
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := resultBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.ResultResponse{File: buf[:n]}); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.streamingHandler(
|
||||
stream.Context(),
|
||||
"result",
|
||||
req,
|
||||
stream,
|
||||
func(data []byte) error {
|
||||
return stream.Send(&agent.ResultResponse{File: data})
|
||||
},
|
||||
func(res any) []byte {
|
||||
return res.(*agent.ResultResponse).File
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *grpcServer) Attestation(req *agent.AttestationRequest, stream agent.AgentService_AttestationServer) error {
|
||||
_, res, err := s.attestation.ServeGRPC(stream.Context(), req)
|
||||
return s.streamingHandler(
|
||||
stream.Context(),
|
||||
"attestation",
|
||||
req,
|
||||
stream,
|
||||
func(data []byte) error {
|
||||
return stream.Send(&agent.AttestationResponse{File: data})
|
||||
},
|
||||
func(res any) []byte {
|
||||
return res.(*agent.AttestationResponse).File
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *grpcServer) IMAMeasurements(req *agent.IMAMeasurementsRequest, stream agent.AgentService_IMAMeasurementsServer) error {
|
||||
_, res, err := s.handlers["imaMeasurements"].ServeGRPC(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr := res.(*agent.AttestationResponse)
|
||||
rr := res.(*agent.IMAMeasurementsResponse)
|
||||
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{FileSizeKey: fmt.Sprint(len(rr.File))})); err != nil {
|
||||
if err := stream.SetHeader(metadata.New(map[string]string{
|
||||
FileSizeKey: strconv.Itoa(len(rr.File)),
|
||||
})); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
|
||||
attestationBuffer := bytes.NewBuffer(rr.File)
|
||||
return s.streamDualBuffers(
|
||||
bytes.NewBuffer(rr.File),
|
||||
bytes.NewBuffer(rr.Pcr10),
|
||||
func(fileData, pcr10Data []byte) error {
|
||||
return stream.Send(&agent.IMAMeasurementsResponse{
|
||||
File: fileData,
|
||||
Pcr10: pcr10Data,
|
||||
})
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
func (s *grpcServer) AzureAttestationToken(ctx context.Context, req *agent.AttestationTokenRequest) (*agent.AttestationTokenResponse, error) {
|
||||
_, res, err := s.handlers["azureAttestationToken"].ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rr, ok := res.(*agent.AttestationTokenResponse)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "failed to cast response to AttestationTokenResponse")
|
||||
}
|
||||
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) streamDualBuffers(
|
||||
buf1, buf2 *bytes.Buffer,
|
||||
sendFn func([]byte, []byte) error,
|
||||
) error {
|
||||
buff1 := make([]byte, bufferSize)
|
||||
buff2 := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := attestationBuffer.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
n1, err1 := buf1.Read(buff1)
|
||||
if err1 != nil && err1 != io.EOF {
|
||||
return status.Error(codes.Internal, err1.Error())
|
||||
}
|
||||
|
||||
if err := stream.Send(&agent.AttestationResponse{File: buf[:n]}); err != nil {
|
||||
n2, err2 := buf2.Read(buff2)
|
||||
if err2 != nil && err2 != io.EOF {
|
||||
return status.Error(codes.Internal, err2.Error())
|
||||
}
|
||||
|
||||
if n1 == 0 && err1 == io.EOF && n2 == 0 && err2 == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err := sendFn(buff1[:n1], buff2[:n2]); err != nil {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+301
-20
@@ -11,8 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -68,8 +67,9 @@ func (m *MockAgentService_ResultServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
func (m *MockAgentService_ResultServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
|
||||
@@ -92,8 +92,46 @@ func (m *MockAgentService_AttestationServer) Send(resp *agent.AttestationRespons
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AttestationServer) SetHeader(metadata.MD) error {
|
||||
return nil
|
||||
func (m *MockAgentService_AttestationServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAgentService_IMAMeasurementsServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) Send(resp *agent.IMAMeasurementsResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_IMAMeasurementsServer) SetHeader(md metadata.MD) error {
|
||||
args := m.Called(md)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
grpcServer, ok := server.(*grpcServer)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, grpcServer.handlers)
|
||||
assert.Len(t, grpcServer.handlers, 6) // Should have 6 handlers
|
||||
|
||||
// Check that all expected handlers are present
|
||||
expectedHandlers := []string{"algo", "data", "result", "attestation", "imaMeasurements", "azureAttestationToken"}
|
||||
for _, handler := range expectedHandlers {
|
||||
assert.Contains(t, grpcServer.handlers, handler)
|
||||
assert.NotNil(t, grpcServer.handlers[handler])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
@@ -102,8 +140,8 @@ func TestAlgo(t *testing.T) {
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil)
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil)
|
||||
|
||||
@@ -114,14 +152,33 @@ func TestAlgo(t *testing.T) {
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAlgoWithMultipleChunks(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("2"), Requirements: []byte("2")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo2"), Requirements: []byte("req2")}).Return(nil)
|
||||
|
||||
err := server.Algo(mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil)
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF).Once()
|
||||
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil).Once()
|
||||
|
||||
mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil)
|
||||
|
||||
@@ -136,9 +193,18 @@ func TestResult(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
resultData := []byte("result data")
|
||||
mockStream := &MockAgentService_ResultServer{ctx: context.Background()}
|
||||
mockService.On("Result", mock.Anything).Return([]byte("result data"), nil)
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil)
|
||||
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call - it should be called with the result data
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.ResultResponse) bool {
|
||||
return len(resp.File) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
mockService.On("Result", mock.Anything).Return(resultData, nil)
|
||||
|
||||
err := server.Result(&agent.ResultRequest{}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
@@ -151,18 +217,135 @@ func TestAttestation(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
attestationData := []byte("attestation data")
|
||||
mockStream := &MockAgentService_AttestationServer{ctx: context.Background()}
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil)
|
||||
|
||||
reportData := [quoteprovider.Nonce]byte{}
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.AttestationResponse) bool {
|
||||
return len(resp.File) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
reportData := [vtpm.SEVNonce]byte{}
|
||||
vtpmNonce := [vtpm.Nonce]byte{}
|
||||
attestationType := config.SNP
|
||||
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return([]byte("attestation data"), nil)
|
||||
attestationType := attestation.SNP
|
||||
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil)
|
||||
|
||||
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream)
|
||||
err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:], Type: int32(attestationType)}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestIMAMeasurements(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
imaData := []byte("ima data")
|
||||
pcr10Data := []byte("pcr10 data")
|
||||
|
||||
mockStream := &MockAgentService_IMAMeasurementsServer{ctx: context.Background()}
|
||||
|
||||
// Mock the SetHeader call
|
||||
mockStream.On("SetHeader", mock.AnythingOfType("metadata.MD")).Return(nil).Once()
|
||||
|
||||
// Mock the Send call
|
||||
mockStream.On("Send", mock.MatchedBy(func(resp *agent.IMAMeasurementsResponse) bool {
|
||||
return len(resp.File) > 0 || len(resp.Pcr10) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
mockService.On("IMAMeasurements", mock.Anything).Return(imaData, pcr10Data, nil)
|
||||
|
||||
err := server.IMAMeasurements(&agent.IMAMeasurementsRequest{}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAttestationToken(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
attestationData := []byte("attestation token data")
|
||||
vtpmNonce := [vtpm.Nonce]byte{}
|
||||
attestationType := attestation.SNP
|
||||
|
||||
mockService.On("AzureAttestationToken", mock.Anything, vtpmNonce).Return(attestationData, nil)
|
||||
|
||||
resp, err := server.AzureAttestationToken(context.Background(), &agent.AttestationTokenRequest{
|
||||
TokenNonce: vtpmNonce[:],
|
||||
Type: int32(attestationType),
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, attestationData, resp.File)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestValidateNonce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nonce []byte
|
||||
maxLen int
|
||||
shouldError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid TEE nonce",
|
||||
nonce: make([]byte, vtpm.SEVNonce),
|
||||
maxLen: vtpm.SEVNonce,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid vTPM nonce",
|
||||
nonce: make([]byte, vtpm.Nonce),
|
||||
maxLen: vtpm.Nonce,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too long",
|
||||
nonce: make([]byte, vtpm.SEVNonce+1),
|
||||
maxLen: vtpm.SEVNonce,
|
||||
shouldError: true,
|
||||
expectedErr: ErrTEENonceLength,
|
||||
},
|
||||
{
|
||||
name: "vTPM nonce too long",
|
||||
nonce: make([]byte, vtpm.Nonce+1),
|
||||
maxLen: vtpm.Nonce,
|
||||
shouldError: true,
|
||||
expectedErr: ErrVTPMNonceLength,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.maxLen == vtpm.SEVNonce {
|
||||
var target [vtpm.SEVNonce]byte
|
||||
err := validateNonce(tt.nonce, tt.maxLen, &target)
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
} else {
|
||||
var target [vtpm.Nonce]byte
|
||||
err := validateNonce(tt.nonce, tt.maxLen, &target)
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeAlgoRequest(t *testing.T) {
|
||||
@@ -204,11 +387,38 @@ func TestEncodeResultResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequest(t *testing.T) {
|
||||
nonce := [quoteprovider.Nonce]byte{}
|
||||
req := &agent.AttestationRequest{TeeNonce: nonce[:]}
|
||||
teeNonce := make([]byte, vtpm.SEVNonce)
|
||||
vtpmNonce := make([]byte, vtpm.Nonce)
|
||||
|
||||
req := &agent.AttestationRequest{
|
||||
TeeNonce: teeNonce,
|
||||
VtpmNonce: vtpmNonce,
|
||||
Type: int32(attestation.SNP),
|
||||
}
|
||||
|
||||
decoded, err := decodeAttestationRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded)
|
||||
|
||||
decodedReq := decoded.(attestationReq)
|
||||
assert.Equal(t, attestation.SNP, decodedReq.AttType)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) {
|
||||
// Test with TEE nonce too long
|
||||
teeNonce := make([]byte, vtpm.SEVNonce+1)
|
||||
req := &agent.AttestationRequest{TeeNonce: teeNonce}
|
||||
|
||||
_, err := decodeAttestationRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrTEENonceLength, err)
|
||||
|
||||
// Test with vTPM nonce too long
|
||||
vtpmNonce := make([]byte, vtpm.Nonce+1)
|
||||
req = &agent.AttestationRequest{VtpmNonce: vtpmNonce}
|
||||
|
||||
_, err = decodeAttestationRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrVTPMNonceLength, err)
|
||||
}
|
||||
|
||||
func TestEncodeAttestationResponse(t *testing.T) {
|
||||
@@ -216,3 +426,74 @@ func TestEncodeAttestationResponse(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationTokenRequest(t *testing.T) {
|
||||
tokenNonce := make([]byte, vtpm.Nonce)
|
||||
req := &agent.AttestationTokenRequest{
|
||||
TokenNonce: tokenNonce,
|
||||
Type: int32(attestation.SNP),
|
||||
}
|
||||
|
||||
_, err := decodeAttestationTokenRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationTokenRequestWithInvalidNonce(t *testing.T) {
|
||||
// Test with token nonce too long
|
||||
tokenNonce := make([]byte, vtpm.Nonce+1)
|
||||
req := &agent.AttestationTokenRequest{TokenNonce: tokenNonce}
|
||||
|
||||
_, err := decodeAttestationTokenRequest(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrVTPMNonceLength, err)
|
||||
}
|
||||
|
||||
func TestEncodeAttestationTokenResponse(t *testing.T) {
|
||||
encoded, err := encodeAttestationTokenResponse(context.Background(), fetchAttestationTokenRes{File: []byte("attestation")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.AttestationTokenResponse{File: []byte("attestation")}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeIMAMeasurementsRequest(t *testing.T) {
|
||||
decoded, err := decodeIMAMeasurementsRequest(context.Background(), &agent.IMAMeasurementsRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, imaMeasurementsReq{}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeIMAMeasurementsResponse(t *testing.T) {
|
||||
encoded, err := encodeIMAMeasurementsResponse(context.Background(), imaMeasurementsRes{
|
||||
File: []byte("ima"),
|
||||
PCR10: []byte("pcr10"),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.IMAMeasurementsResponse{
|
||||
File: []byte("ima"),
|
||||
Pcr10: []byte("pcr10"),
|
||||
}, encoded)
|
||||
}
|
||||
|
||||
func TestAlgoWithStreamError(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, assert.AnError).Once()
|
||||
|
||||
err := server.Algo(mockStream)
|
||||
assert.Error(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDataWithStreamError(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, assert.AnError).Once()
|
||||
|
||||
err := server.Data(mockStream)
|
||||
assert.Error(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
+28
-4
@@ -2,7 +2,6 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package api
|
||||
|
||||
@@ -13,8 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -106,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e
|
||||
return lm.svc.Result(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) (response []byte, err error) {
|
||||
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
@@ -118,3 +116,29 @@ func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quotep
|
||||
|
||||
return lm.svc.Attestation(ctx, reportData, nonce, attType)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) IMAMeasurements(ctx context.Context) (file []byte, pcr10 []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method IMAMeasurements took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s", message, err))
|
||||
return
|
||||
}
|
||||
lm.logger.Info(fmt.Sprintf("%s without errors", message))
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.IMAMeasurements(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) (response []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method AzureAttestationToken took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s", message, err))
|
||||
return
|
||||
}
|
||||
lm.logger.Info(fmt.Sprintf("%s without errors", message))
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.AzureAttestationToken(ctx, nonce)
|
||||
}
|
||||
|
||||
+20
-4
@@ -2,7 +2,6 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package api
|
||||
|
||||
@@ -12,8 +11,7 @@ import (
|
||||
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -92,7 +90,7 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) {
|
||||
return ms.svc.Result(ctx)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType config.AttestationType) ([]byte, error) {
|
||||
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "attestation").Add(1)
|
||||
ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds())
|
||||
@@ -100,3 +98,21 @@ func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quotep
|
||||
|
||||
return ms.svc.Attestation(ctx, reportData, nonce, attType)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "attestation_token").Add(1)
|
||||
ms.latency.With("method", "attestation_token").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.AzureAttestationToken(ctx, nonce)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "imameasurements").Add(1)
|
||||
ms.latency.With("method", "imameasurements").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.IMAMeasurements(ctx)
|
||||
}
|
||||
|
||||
+3
-3
@@ -41,9 +41,9 @@ type Authenticator interface {
|
||||
}
|
||||
|
||||
type service struct {
|
||||
resultConsumers []interface{}
|
||||
datasetProviders []interface{}
|
||||
algorithmProvider interface{}
|
||||
resultConsumers []any
|
||||
datasetProviders []any
|
||||
algorithmProvider any
|
||||
}
|
||||
|
||||
func New(manifest agent.Computation) (Authenticator, error) {
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
manifest := agent.Computation{
|
||||
ResultConsumers: []agent.ResultConsumer{{UserKey: resultConsumerPubKey}},
|
||||
Datasets: []agent.Dataset{{UserKey: dataProviderPubKey}},
|
||||
Algorithm: agent.Algorithm{UserKey: algorithmProviderPubKey},
|
||||
Algorithm: &agent.Algorithm{UserKey: algorithmProviderPubKey},
|
||||
}
|
||||
|
||||
auth, err := New(manifest)
|
||||
|
||||
@@ -1,18 +1,33 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
auth "github.com/ultravioletrs/cocos/agent/auth"
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
)
|
||||
|
||||
// NewAuthenticator creates a new instance of Authenticator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAuthenticator(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Authenticator {
|
||||
mock := &Authenticator{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Authenticator is an autogenerated mock type for the Authenticator type
|
||||
type Authenticator struct {
|
||||
mock.Mock
|
||||
@@ -26,9 +41,9 @@ func (_m *Authenticator) EXPECT() *Authenticator_Expecter {
|
||||
return &Authenticator_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AuthenticateUser provides a mock function with given fields: ctx, role
|
||||
func (_m *Authenticator) AuthenticateUser(ctx context.Context, role auth.UserRole) (context.Context, error) {
|
||||
ret := _m.Called(ctx, role)
|
||||
// AuthenticateUser provides a mock function for the type Authenticator
|
||||
func (_mock *Authenticator) AuthenticateUser(ctx context.Context, role auth.UserRole) (context.Context, error) {
|
||||
ret := _mock.Called(ctx, role)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AuthenticateUser")
|
||||
@@ -36,23 +51,21 @@ func (_m *Authenticator) AuthenticateUser(ctx context.Context, role auth.UserRol
|
||||
|
||||
var r0 context.Context
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, auth.UserRole) (context.Context, error)); ok {
|
||||
return rf(ctx, role)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, auth.UserRole) (context.Context, error)); ok {
|
||||
return returnFunc(ctx, role)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, auth.UserRole) context.Context); ok {
|
||||
r0 = rf(ctx, role)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, auth.UserRole) context.Context); ok {
|
||||
r0 = returnFunc(ctx, role)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, auth.UserRole) error); ok {
|
||||
r1 = rf(ctx, role)
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, auth.UserRole) error); ok {
|
||||
r1 = returnFunc(ctx, role)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -70,31 +83,28 @@ func (_e *Authenticator_Expecter) AuthenticateUser(ctx interface{}, role interfa
|
||||
|
||||
func (_c *Authenticator_AuthenticateUser_Call) Run(run func(ctx context.Context, role auth.UserRole)) *Authenticator_AuthenticateUser_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(auth.UserRole))
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 auth.UserRole
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(auth.UserRole)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Authenticator_AuthenticateUser_Call) Return(_a0 context.Context, _a1 error) *Authenticator_AuthenticateUser_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
func (_c *Authenticator_AuthenticateUser_Call) Return(context1 context.Context, err error) *Authenticator_AuthenticateUser_Call {
|
||||
_c.Call.Return(context1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Authenticator_AuthenticateUser_Call) RunAndReturn(run func(context.Context, auth.UserRole) (context.Context, error)) *Authenticator_AuthenticateUser_Call {
|
||||
func (_c *Authenticator_AuthenticateUser_Call) RunAndReturn(run func(ctx context.Context, role auth.UserRole) (context.Context, error)) *Authenticator_AuthenticateUser_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new instance of Authenticator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAuthenticator(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Authenticator {
|
||||
mock := &Authenticator{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+43
-10
@@ -13,7 +13,6 @@ import (
|
||||
var _ fmt.Stringer = (*Datasets)(nil)
|
||||
|
||||
type AgentConfig struct {
|
||||
Port string `json:"port,omitempty"`
|
||||
CertFile string `json:"cert_file,omitempty"`
|
||||
KeyFile string `json:"server_key,omitempty"`
|
||||
ServerCAFile string `json:"server_ca_file,omitempty"`
|
||||
@@ -21,12 +20,39 @@ type AgentConfig struct {
|
||||
AttestedTls bool `json:"attested_tls,omitempty"`
|
||||
}
|
||||
|
||||
// ResourceSource specifies the location of a remote encrypted resource.
|
||||
type ResourceSource struct {
|
||||
// Type is the type of resource source.
|
||||
// Supported values: "oci-image", "s3", "gcs", "https", "http"
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is the location of the resource.
|
||||
// Examples:
|
||||
// - OCI: "docker://registry/repo:tag"
|
||||
// - S3: "s3://bucket/key"
|
||||
// - GCS: "gs://bucket/key"
|
||||
// - HTTPS: "https://host/path/to/file"
|
||||
// - HTTP: "http://host/path/to/file"
|
||||
URL string `json:"url,omitempty"`
|
||||
// KBSResourcePath is the path to the decryption key in KBS (e.g., "default/key/my-key")
|
||||
KBSResourcePath string `json:"kbs_resource_path,omitempty"`
|
||||
// Encrypted indicates whether the resource is encrypted and requires KBS
|
||||
Encrypted bool `json:"encrypted,omitempty"`
|
||||
}
|
||||
|
||||
// KBSConfig holds configuration for Key Broker Service.
|
||||
type KBSConfig struct {
|
||||
// URL is the KBS endpoint (e.g., "https://kbs.example.com")
|
||||
URL string `json:"url,omitempty"`
|
||||
// Enabled indicates whether to use KBS for key retrieval
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
type Computation struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Datasets Datasets `json:"datasets,omitempty"`
|
||||
Algorithm Algorithm `json:"algorithm,omitempty"`
|
||||
Algorithm *Algorithm `json:"algorithm,omitempty"`
|
||||
ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"`
|
||||
}
|
||||
|
||||
@@ -43,19 +69,26 @@ func (d *Datasets) String() string {
|
||||
}
|
||||
|
||||
type Dataset struct {
|
||||
Dataset []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Dataset []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
Decompress bool `json:"decompress,omitempty"`
|
||||
KBS *KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type Datasets []Dataset
|
||||
|
||||
type Algorithm struct {
|
||||
Algorithm []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Requirements []byte `json:"-"`
|
||||
Algorithm []byte `json:"-"`
|
||||
Hash [32]byte `json:"hash,omitempty"`
|
||||
UserKey []byte `json:"user_key,omitempty"`
|
||||
Requirements []byte `json:"-"`
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
AlgoType string `json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `json:"algo_args,omitempty"`
|
||||
KBS *KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type ManifestIndexKey struct{}
|
||||
|
||||
@@ -105,16 +105,15 @@ func TestDecompressToContext(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentConfigJSON(t *testing.T) {
|
||||
config := AgentConfig{
|
||||
Port: "8080",
|
||||
cfg := AgentConfig{
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server_ca.pem",
|
||||
ClientCAFile: "client_ca.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(config)
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal AgentConfig: %v", err)
|
||||
}
|
||||
@@ -125,7 +124,7 @@ func TestAgentConfigJSON(t *testing.T) {
|
||||
t.Fatalf("Failed to unmarshal AgentConfig: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(config, unmarshaledConfig) {
|
||||
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config)
|
||||
if !reflect.DeepEqual(cfg, unmarshaledConfig) {
|
||||
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
+103
-15
@@ -4,6 +4,8 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,7 +15,10 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -21,12 +26,11 @@ import (
|
||||
const (
|
||||
reconnectInterval = 5 * time.Second
|
||||
sendTimeout = 5 * time.Second
|
||||
pendingMsgFile = "pending_messages.json"
|
||||
)
|
||||
|
||||
var (
|
||||
errCorruptedManifest = errors.New("received manifest may be corrupted")
|
||||
errUnknonwMessageType = errors.New("unknown message type")
|
||||
errUnknownMessageType = errors.New("unknown message type")
|
||||
)
|
||||
|
||||
type PendingMessage struct {
|
||||
@@ -42,13 +46,14 @@ type CVMSClient struct {
|
||||
logger *slog.Logger
|
||||
runReqManager *runRequestManager
|
||||
sp server.AgentServer
|
||||
ingressProxy ingress.ProxyServer
|
||||
storage storage.Storage
|
||||
reconnectFn func(context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error)
|
||||
grpcClient pkggrpc.Client
|
||||
reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error)
|
||||
grpcClient grpc.Client
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, storageDir string, reconnectFn func(context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error), grpcClient pkggrpc.Client) (*CVMSClient, error) {
|
||||
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, ingressProxy ingress.ProxyServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) {
|
||||
store, err := storage.NewFileStorage(storageDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,6 +66,7 @@ func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueu
|
||||
logger: logger,
|
||||
runReqManager: newRunRequestManager(),
|
||||
sp: sp,
|
||||
ingressProxy: ingressProxy,
|
||||
storage: store,
|
||||
reconnectFn: reconnectFn,
|
||||
grpcClient: grpcClient,
|
||||
@@ -184,7 +190,7 @@ func (client *CVMSClient) processIncomingMessage(ctx context.Context, req *cvms.
|
||||
}
|
||||
client.mu.Unlock()
|
||||
default:
|
||||
return errUnknonwMessageType
|
||||
return errUnknownMessageType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -203,14 +209,17 @@ func (client *CVMSClient) handleAgentStateReq(mes *cvms.ServerStreamMessage_Agen
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
|
||||
client.logger.Debug("Received RunReq chunk", "id", msg.RunReqChunks.Id, "size", len(msg.RunReqChunks.Data), "isLast", msg.RunReqChunks.IsLast)
|
||||
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
|
||||
|
||||
if complete {
|
||||
client.logger.Info("Received complete computation run request", "id", msg.RunReqChunks.Id, "totalSize", len(buffer))
|
||||
var runReq cvms.ComputationRunReq
|
||||
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
||||
return errors.Wrap(err, errCorruptedManifest)
|
||||
}
|
||||
|
||||
client.logger.Info("Starting computation execution", "computationId", runReq.Id, "name", runReq.Name)
|
||||
go client.executeRun(ctx, &runReq)
|
||||
}
|
||||
|
||||
@@ -225,17 +234,50 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
}
|
||||
|
||||
if runReq.Algorithm != nil {
|
||||
ac.Algorithm = agent.Algorithm{
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
ac.Algorithm = &agent.Algorithm{
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
AlgoType: runReq.Algorithm.AlgoType,
|
||||
}
|
||||
// Copy remote source if configured
|
||||
if runReq.Algorithm.Source != nil {
|
||||
ac.Algorithm.Source = &agent.ResourceSource{
|
||||
URL: runReq.Algorithm.Source.Url,
|
||||
KBSResourcePath: runReq.Algorithm.Source.KbsResourcePath,
|
||||
Encrypted: runReq.Algorithm.Source.Encrypted,
|
||||
}
|
||||
}
|
||||
ac.Algorithm.AlgoArgs = runReq.Algorithm.AlgoArgs
|
||||
if runReq.Algorithm.Kbs != nil {
|
||||
ac.Algorithm.KBS = &agent.KBSConfig{
|
||||
URL: runReq.Algorithm.Kbs.Url,
|
||||
Enabled: runReq.Algorithm.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ds := range runReq.Datasets {
|
||||
ac.Datasets = append(ac.Datasets, agent.Dataset{
|
||||
Hash: [32]byte(ds.Hash),
|
||||
UserKey: ds.UserKey,
|
||||
})
|
||||
dataset := agent.Dataset{
|
||||
Hash: [32]byte(ds.Hash),
|
||||
UserKey: ds.UserKey,
|
||||
Filename: ds.Filename,
|
||||
}
|
||||
// Copy remote source if configured
|
||||
if ds.Source != nil {
|
||||
dataset.Source = &agent.ResourceSource{
|
||||
URL: ds.Source.Url,
|
||||
KBSResourcePath: ds.Source.KbsResourcePath,
|
||||
Encrypted: ds.Source.Encrypted,
|
||||
}
|
||||
}
|
||||
dataset.Decompress = ds.Decompress
|
||||
if ds.Kbs != nil {
|
||||
dataset.KBS = &agent.KBSConfig{
|
||||
URL: ds.Kbs.Url,
|
||||
Enabled: ds.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
ac.Datasets = append(ac.Datasets, dataset)
|
||||
}
|
||||
|
||||
for _, rc := range runReq.ResultConsumers {
|
||||
@@ -244,11 +286,22 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
})
|
||||
}
|
||||
|
||||
// Check if the agent is in the correct state to initialize a new computation.
|
||||
// If the agent is already processing this computation (e.g., after a reconnection),
|
||||
// skip initialization to avoid state errors.
|
||||
currentState := client.svc.State()
|
||||
if currentState != "ReceivingManifest" {
|
||||
client.logger.Info("Agent already processing computation, skipping initialization", "state", currentState, "computationId", runReq.Id)
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.svc.InitComputation(ctx, ac); err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ccPlatform := attestation.CCPlatform()
|
||||
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
@@ -263,7 +316,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
}
|
||||
|
||||
if err := client.sp.Start(agent.AgentConfig{
|
||||
Port: runReq.AgentConfig.Port,
|
||||
CertFile: runReq.AgentConfig.CertFile,
|
||||
KeyFile: runReq.AgentConfig.KeyFile,
|
||||
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
||||
@@ -274,6 +326,36 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
runRes.RunRes.Error = err.Error()
|
||||
}
|
||||
|
||||
// Start ingress proxy if available
|
||||
if client.ingressProxy != nil {
|
||||
if err := client.ingressProxy.Start(
|
||||
ingress.AgentConfigToProxyConfig(agent.AgentConfig{
|
||||
CertFile: runReq.AgentConfig.CertFile,
|
||||
KeyFile: runReq.AgentConfig.KeyFile,
|
||||
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
||||
ClientCAFile: runReq.AgentConfig.ClientCaFile,
|
||||
AttestedTls: runReq.AgentConfig.AttestedTls,
|
||||
}),
|
||||
ingress.ComputationToProxyContext(ac),
|
||||
); err != nil {
|
||||
client.logger.Warn(fmt.Sprintf("failed to start ingress proxy: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if ccPlatform == attestation.Azure || ccPlatform == attestation.SNPvTPM {
|
||||
cmpJson, err := json.Marshal(ac)
|
||||
if err != nil {
|
||||
client.logger.Error(err.Error())
|
||||
return
|
||||
}
|
||||
if err = vtpm.ExtendPCR(vtpm.PCR16, cmpJson); err != nil {
|
||||
client.logger.Error(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
client.sendMessage(&cvms.ClientStreamMessage{Message: runRes})
|
||||
}
|
||||
|
||||
@@ -291,6 +373,12 @@ func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.S
|
||||
if err := client.sp.Stop(); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
// Stop ingress proxy if available
|
||||
if client.ingressProxy != nil {
|
||||
if err := client.ingressProxy.Stop(); err != nil {
|
||||
client.logger.Warn(fmt.Sprintf("failed to stop ingress proxy: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
|
||||
|
||||
@@ -10,11 +10,15 @@ import (
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
clientmocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -34,16 +38,31 @@ func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// mockIngressProxy is a mock implementation of the ingress proxy.
|
||||
type mockIngressProxy struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockIngressProxy) Start(config ingress.ProxyConfig, ctx ingress.ProxyContext) error {
|
||||
args := m.Called(config, ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockIngressProxy) Stop() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer)
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Stop computation",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{},
|
||||
@@ -58,7 +77,7 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Run request chunks",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{},
|
||||
@@ -69,9 +88,37 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Agent state request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_AgentStateReq{
|
||||
AgentStateReq: &cvms.AgentStateReq{
|
||||
Id: "test-agent",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
mockSvc.On("State").Return("test-state")
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Disconnect request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_DisconnectReq{},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
grpcClient.On("Close").Return(nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Receive error",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{}, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
@@ -92,13 +139,13 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
tc.setupMocks(mockStream, mockSvc, mockServerSvc)
|
||||
tc.setupMocks(mockStream, mockSvc, mockServerSvc, grpcClient)
|
||||
|
||||
err = client.Process(ctx, cancel)
|
||||
|
||||
@@ -122,11 +169,24 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
Datasets: []*cvms.Dataset{
|
||||
{
|
||||
Hash: sha3.New256().Sum([]byte("test-dataset")),
|
||||
},
|
||||
},
|
||||
Algorithm: &cvms.Algorithm{
|
||||
Hash: sha3.New256().Sum([]byte("test-algorithm")),
|
||||
},
|
||||
ResultConsumers: []*cvms.ResultConsumer{
|
||||
{
|
||||
UserKey: []byte("test-consumer"),
|
||||
},
|
||||
},
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
@@ -145,6 +205,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("State").Return("ReceivingManifest")
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
@@ -174,7 +235,7 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
@@ -213,3 +274,381 @@ func TestManagerClient_timeoutRequest(t *testing.T) {
|
||||
|
||||
assert.Len(t, rm.requests, 0)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendPendingMessages tests sending pending messages on reconnection.
|
||||
func TestManagerClient_sendPendingMessages(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a pending message to storage
|
||||
testMsg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.storage.Add(testMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Mock successful send
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Load and send pending messages
|
||||
pending, err := client.storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, pending, 1)
|
||||
|
||||
client.sendPendingMessages(pending)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendPendingMessagesWithError tests pending message send failure.
|
||||
func TestManagerClient_sendPendingMessagesWithError(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testMsg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock failed send
|
||||
mockStream.On("Send", mock.Anything).Return(assert.AnError)
|
||||
|
||||
pending := []storage.Message{
|
||||
{
|
||||
Message: testMsg,
|
||||
Time: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
client.sendPendingMessages(pending)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_addChunkTimeout tests chunk timeout in runRequestManager.
|
||||
func TestManagerClient_addChunkTimeout(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
|
||||
// Add first chunk
|
||||
chunk1 := []byte("chunk1")
|
||||
buffer, complete := rm.addChunk("test-id", chunk1, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
// Verify request exists
|
||||
rm.mu.Lock()
|
||||
assert.Contains(t, rm.requests, "test-id")
|
||||
rm.mu.Unlock()
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(35 * time.Second) // runReqTimeout is 30 seconds
|
||||
|
||||
// Verify request was removed
|
||||
rm.mu.Lock()
|
||||
assert.NotContains(t, rm.requests, "test-id")
|
||||
rm.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestManagerClient_addChunkMultiple tests adding multiple chunks.
|
||||
func TestManagerClient_addChunkMultiple(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
|
||||
chunk1 := []byte("chunk1")
|
||||
chunk2 := []byte("chunk2")
|
||||
chunk3 := []byte("chunk3")
|
||||
|
||||
// Add chunks
|
||||
buffer, complete := rm.addChunk("test-id", chunk1, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
buffer, complete = rm.addChunk("test-id", chunk2, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
buffer, complete = rm.addChunk("test-id", chunk3, true)
|
||||
assert.NotNil(t, buffer)
|
||||
assert.True(t, complete)
|
||||
|
||||
expected := append(append(chunk1, chunk2...), chunk3...)
|
||||
assert.Equal(t, expected, buffer)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleStopComputationWithIngressProxy tests stop with ingress proxy.
|
||||
func TestManagerClient_handleStopComputationWithIngressProxy(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
mockIngressProxy := new(mockIngressProxy)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
mockIngressProxy.On("Stop").Return(nil)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
mockServerSvc.AssertExpectations(t)
|
||||
mockIngressProxy.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleStopComputationWithIngressProxyError tests stop with ingress proxy error.
|
||||
func TestManagerClient_handleStopComputationWithIngressProxyError(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
mockIngressProxy := new(mockIngressProxy)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
mockIngressProxy.On("Stop").Return(assert.AnError)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockIngressProxy.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendMessage tests sendMessage with timeout.
|
||||
func TestManagerClient_sendMessage(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 1)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client.sendMessage(msg)
|
||||
|
||||
select {
|
||||
case received := <-messageQueue:
|
||||
assert.Equal(t, msg, received)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Message not received")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerClient_sendMessageTimeout tests sendMessage timeout when queue is full.
|
||||
func TestManagerClient_sendMessageTimeout(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage) // No buffer
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Don't read from queue, so sendMessage will timeout
|
||||
client.sendMessage(msg)
|
||||
|
||||
// Should complete without blocking
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleRunReqChunksWithRemoteSource tests handling run request with remote source.
|
||||
func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id-remote",
|
||||
Name: "test-computation",
|
||||
Description: "test description",
|
||||
Datasets: []*cvms.Dataset{
|
||||
{
|
||||
Hash: sha3.New256().Sum([]byte("test-dataset")),
|
||||
Filename: "data.csv",
|
||||
Source: &cvms.Source{
|
||||
Type: "oci-image",
|
||||
Url: "docker://registry.example.com/data:v1",
|
||||
KbsResourcePath: "default/key/data-key",
|
||||
Encrypted: true,
|
||||
},
|
||||
Decompress: true,
|
||||
},
|
||||
},
|
||||
Algorithm: &cvms.Algorithm{
|
||||
Hash: sha3.New256().Sum([]byte("test-algorithm")),
|
||||
AlgoType: "python",
|
||||
AlgoArgs: []string{"--verbose"},
|
||||
Source: &cvms.Source{
|
||||
Type: "oci-image",
|
||||
Url: "docker://registry.example.com/algo:v1",
|
||||
KbsResourcePath: "default/key/algo-key",
|
||||
Encrypted: true,
|
||||
},
|
||||
Kbs: &cvms.KBSConfig{
|
||||
Url: "https://kbs.example.com:8080",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
|
||||
ResultConsumers: []*cvms.ResultConsumer{
|
||||
{
|
||||
UserKey: []byte("test-consumer"),
|
||||
},
|
||||
},
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-remote-1",
|
||||
Data: runReqBytes,
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("State").Return("ReceivingManifest")
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.MatchedBy(func(c agent.Computation) bool {
|
||||
// Verify Algorithm KBS config is passed
|
||||
if c.Algorithm.KBS == nil || !c.Algorithm.KBS.Enabled || c.Algorithm.KBS.URL != "https://kbs.example.com:8080" {
|
||||
return false
|
||||
}
|
||||
// Verify algorithm source is passed
|
||||
if c.Algorithm.Source == nil ||
|
||||
c.Algorithm.Source.URL != "docker://registry.example.com/algo:v1" ||
|
||||
c.Algorithm.Source.KBSResourcePath != "default/key/algo-key" ||
|
||||
!c.Algorithm.Source.Encrypted {
|
||||
return false
|
||||
}
|
||||
// Verify algorithm type and args
|
||||
if c.Algorithm.AlgoType != "python" || len(c.Algorithm.AlgoArgs) != 1 || c.Algorithm.AlgoArgs[0] != "--verbose" {
|
||||
return false
|
||||
}
|
||||
// Verify dataset source is passed
|
||||
if len(c.Datasets) != 1 ||
|
||||
c.Datasets[0].Source == nil ||
|
||||
c.Datasets[0].Source.URL != "docker://registry.example.com/data:v1" ||
|
||||
c.Datasets[0].Filename != "data.csv" ||
|
||||
!c.Datasets[0].Decompress {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})).Return(nil)
|
||||
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleRunReqChunksAlreadyProcessing tests skipping init when already processing.
|
||||
func TestManagerClient_handleRunReqChunksAlreadyProcessing(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id-processing",
|
||||
Name: "test-computation",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-processing-1",
|
||||
Data: runReqBytes,
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate agent already processing a computation
|
||||
mockSvc.On("State").Return("Running")
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// InitComputation should NOT be called since state is not ReceivingManifest
|
||||
mockSvc.AssertNotCalled(t, "InitComputation")
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
@@ -52,16 +53,20 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
|
||||
return errors.New("failed to get peer info")
|
||||
}
|
||||
|
||||
slog.Info("client connected to cvms server", "address", client.Addr.String())
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("receive goroutine context done", "address", client.Addr.String())
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
slog.Error("failed to receive from stream", "address", client.Addr.String(), "error", err)
|
||||
return err
|
||||
}
|
||||
s.incoming <- req
|
||||
@@ -85,10 +90,13 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
|
||||
}
|
||||
|
||||
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
|
||||
slog.Info("send goroutine Run() returned", "address", client.Addr.String())
|
||||
return nil
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
err := eg.Wait()
|
||||
slog.Info("stream closed", "address", client.Addr.String(), "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error {
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
cvms "github.com/ultravioletrs/cocos/agent/cvms"
|
||||
|
||||
storage "github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
)
|
||||
|
||||
// NewStorage creates a new instance of Storage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStorage(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Storage {
|
||||
mock := &Storage{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Storage is an autogenerated mock type for the Storage type
|
||||
type Storage struct {
|
||||
mock.Mock
|
||||
@@ -25,21 +40,20 @@ func (_m *Storage) EXPECT() *Storage_Expecter {
|
||||
return &Storage_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Add provides a mock function with given fields: msg
|
||||
func (_m *Storage) Add(msg *cvms.ClientStreamMessage) error {
|
||||
ret := _m.Called(msg)
|
||||
// Add provides a mock function for the type Storage
|
||||
func (_mock *Storage) Add(msg *cvms.ClientStreamMessage) error {
|
||||
ret := _mock.Called(msg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Add")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*cvms.ClientStreamMessage) error); ok {
|
||||
r0 = rf(msg)
|
||||
if returnFunc, ok := ret.Get(0).(func(*cvms.ClientStreamMessage) error); ok {
|
||||
r0 = returnFunc(msg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -56,36 +70,41 @@ func (_e *Storage_Expecter) Add(msg interface{}) *Storage_Add_Call {
|
||||
|
||||
func (_c *Storage_Add_Call) Run(run func(msg *cvms.ClientStreamMessage)) *Storage_Add_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*cvms.ClientStreamMessage))
|
||||
var arg0 *cvms.ClientStreamMessage
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*cvms.ClientStreamMessage)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Add_Call) Return(_a0 error) *Storage_Add_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *Storage_Add_Call) Return(err error) *Storage_Add_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Add_Call) RunAndReturn(run func(*cvms.ClientStreamMessage) error) *Storage_Add_Call {
|
||||
func (_c *Storage_Add_Call) RunAndReturn(run func(msg *cvms.ClientStreamMessage) error) *Storage_Add_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Clear provides a mock function with no fields
|
||||
func (_m *Storage) Clear() error {
|
||||
ret := _m.Called()
|
||||
// Clear provides a mock function for the type Storage
|
||||
func (_mock *Storage) Clear() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Clear")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -106,8 +125,8 @@ func (_c *Storage_Clear_Call) Run(run func()) *Storage_Clear_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Clear_Call) Return(_a0 error) *Storage_Clear_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *Storage_Clear_Call) Return(err error) *Storage_Clear_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -116,9 +135,9 @@ func (_c *Storage_Clear_Call) RunAndReturn(run func() error) *Storage_Clear_Call
|
||||
return _c
|
||||
}
|
||||
|
||||
// Load provides a mock function with no fields
|
||||
func (_m *Storage) Load() ([]storage.Message, error) {
|
||||
ret := _m.Called()
|
||||
// Load provides a mock function for the type Storage
|
||||
func (_mock *Storage) Load() ([]storage.Message, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Load")
|
||||
@@ -126,23 +145,21 @@ func (_m *Storage) Load() ([]storage.Message, error) {
|
||||
|
||||
var r0 []storage.Message
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() ([]storage.Message, error)); ok {
|
||||
return rf()
|
||||
if returnFunc, ok := ret.Get(0).(func() ([]storage.Message, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() []storage.Message); ok {
|
||||
r0 = rf()
|
||||
if returnFunc, ok := ret.Get(0).(func() []storage.Message); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]storage.Message)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -163,8 +180,8 @@ func (_c *Storage_Load_Call) Run(run func()) *Storage_Load_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Load_Call) Return(_a0 []storage.Message, _a1 error) *Storage_Load_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
func (_c *Storage_Load_Call) Return(messages []storage.Message, err error) *Storage_Load_Call {
|
||||
_c.Call.Return(messages, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -173,21 +190,20 @@ func (_c *Storage_Load_Call) RunAndReturn(run func() ([]storage.Message, error))
|
||||
return _c
|
||||
}
|
||||
|
||||
// Save provides a mock function with given fields: messages
|
||||
func (_m *Storage) Save(messages []storage.Message) error {
|
||||
ret := _m.Called(messages)
|
||||
// Save provides a mock function for the type Storage
|
||||
func (_mock *Storage) Save(messages []storage.Message) error {
|
||||
ret := _mock.Called(messages)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Save")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]storage.Message) error); ok {
|
||||
r0 = rf(messages)
|
||||
if returnFunc, ok := ret.Get(0).(func([]storage.Message) error); ok {
|
||||
r0 = returnFunc(messages)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -204,31 +220,23 @@ func (_e *Storage_Expecter) Save(messages interface{}) *Storage_Save_Call {
|
||||
|
||||
func (_c *Storage_Save_Call) Run(run func(messages []storage.Message)) *Storage_Save_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]storage.Message))
|
||||
var arg0 []storage.Message
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].([]storage.Message)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Save_Call) Return(_a0 error) *Storage_Save_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *Storage_Save_Call) Return(err error) *Storage_Save_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Storage_Save_Call) RunAndReturn(run func([]storage.Message) error) *Storage_Save_Call {
|
||||
func (_c *Storage_Save_Call) RunAndReturn(run func(messages []storage.Message) error) *Storage_Save_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewStorage creates a new instance of Storage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStorage(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Storage {
|
||||
mock := &Storage{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
@@ -0,0 +1,450 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
)
|
||||
|
||||
func createTempDir(t *testing.T) string {
|
||||
tmpDir, err := os.MkdirTemp("", "storage_test_*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
})
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
func createTestMessage(content string) *cvms.ClientStreamMessage {
|
||||
return &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
Error: "",
|
||||
ComputationId: content,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileStorage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storageDir string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid directory",
|
||||
storageDir: createTempDir(t),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent directory gets created",
|
||||
storageDir: filepath.Join(createTempDir(t), "subdir"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid directory path",
|
||||
storageDir: "/invalid/path/that/cannot/be/created",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
storage, err := NewFileStorage(tt.storageDir)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, storage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, storage)
|
||||
assert.Equal(t, filepath.Join(tt.storageDir, "pending_messages.json"), storage.path)
|
||||
assert.Empty(t, storage.msgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Load(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) error
|
||||
expectedMsgs int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "load from non-existent file",
|
||||
setupFile: func(path string) error {
|
||||
// Don't create file
|
||||
return nil
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "load from empty file",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte("[]"), 0o644)
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "load from corrupted file",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte("invalid json"), 0o644)
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tt.setupFile(storage.path)
|
||||
require.NoError(t, err)
|
||||
|
||||
msgs, err := storage.Load()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, tt.expectedMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Save(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []Message
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "save empty messages",
|
||||
messages: []Message{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "save single message",
|
||||
messages: []Message{
|
||||
{
|
||||
Message: createTestMessage("test"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "save multiple messages",
|
||||
messages: []Message{
|
||||
{
|
||||
Message: createTestMessage("test1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
{
|
||||
Message: createTestMessage("test2"),
|
||||
Time: time.Now().Add(time.Second),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = storage.Save(tt.messages)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify file was written correctly
|
||||
_, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify internal state was updated
|
||||
assert.Equal(t, tt.messages, storage.msgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Add(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMsgs []Message
|
||||
newMessage *cvms.ClientStreamMessage
|
||||
expectError bool
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "add to empty storage",
|
||||
initialMsgs: []Message{},
|
||||
newMessage: createTestMessage("new"),
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "add to existing messages",
|
||||
initialMsgs: []Message{
|
||||
{
|
||||
Message: createTestMessage("existing"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
newMessage: createTestMessage("new"),
|
||||
expectError: false,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "add nil message",
|
||||
initialMsgs: []Message{},
|
||||
newMessage: nil,
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup initial messages
|
||||
if len(tt.initialMsgs) > 0 {
|
||||
err = storage.Save(tt.initialMsgs)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
beforeTime := time.Now()
|
||||
err = storage.Add(tt.newMessage)
|
||||
afterTime := time.Now()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify message was added to internal state
|
||||
assert.Len(t, storage.msgs, tt.expectedCount)
|
||||
|
||||
// Verify timestamp is reasonable
|
||||
if tt.expectedCount > 0 {
|
||||
lastMsg := storage.msgs[len(storage.msgs)-1]
|
||||
assert.True(t, lastMsg.Time.After(beforeTime) || lastMsg.Time.Equal(beforeTime))
|
||||
assert.True(t, lastMsg.Time.Before(afterTime) || lastMsg.Time.Equal(afterTime))
|
||||
assert.Equal(t, tt.newMessage, lastMsg.Message)
|
||||
}
|
||||
|
||||
_, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Clear(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMsgs []Message
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "clear empty storage",
|
||||
initialMsgs: []Message{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "clear storage with messages",
|
||||
initialMsgs: []Message{
|
||||
{
|
||||
Message: createTestMessage("test1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
{
|
||||
Message: createTestMessage("test2"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup initial messages
|
||||
if len(tt.initialMsgs) > 0 {
|
||||
err = storage.Save(tt.initialMsgs)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = storage.Clear()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify internal state is cleared
|
||||
assert.Empty(t, storage.msgs)
|
||||
|
||||
// Verify file contains empty array
|
||||
data, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "[]", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_ConcurrentAccess(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent Add operations
|
||||
numGoroutines := 10
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
msg := createTestMessage(string(rune('A' + id)))
|
||||
err := storage.Add(msg)
|
||||
assert.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all messages were added
|
||||
msgs, err := storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, numGoroutines)
|
||||
}
|
||||
|
||||
func TestFileStorage_IntegrationFlow(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test full workflow
|
||||
|
||||
// 1. Load from empty storage
|
||||
msgs, err := storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, msgs)
|
||||
|
||||
// 2. Add some messages
|
||||
msg1 := createTestMessage("message1")
|
||||
err = storage.Add(msg1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg2 := createTestMessage("message2")
|
||||
err = storage.Add(msg2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 3. Load and verify
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, 2)
|
||||
|
||||
// 4. Save new set of messages
|
||||
newMsgs := []Message{
|
||||
{
|
||||
Message: createTestMessage("new1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
}
|
||||
err = storage.Save(newMsgs)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5. Load and verify replacement
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, 1)
|
||||
|
||||
// 6. Clear storage
|
||||
err = storage.Clear()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 7. Verify empty
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, msgs)
|
||||
}
|
||||
|
||||
func TestFileStorage_FilePermissions(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a message to create the file
|
||||
msg := createTestMessage("test")
|
||||
err = storage.Add(msg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check file permissions
|
||||
info, err := os.Stat(storage.path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, os.FileMode(0o644), info.Mode().Perm())
|
||||
}
|
||||
|
||||
func TestFileStorage_ErrorHandling(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make directory read-only to trigger write errors
|
||||
err = os.Chmod(tmpDir, 0o555)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Restore permissions for cleanup
|
||||
t.Cleanup(func() {
|
||||
if err := os.Chmod(tmpDir, 0o755); err != nil {
|
||||
t.Errorf("Failed to restore permissions: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Try to add a message - should fail due to write permissions
|
||||
msg := createTestMessage("test")
|
||||
err = storage.Add(msg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to save - should fail due to write permissions
|
||||
err = storage.Save([]Message{})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to clear - should fail due to write permissions
|
||||
err = storage.Clear()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
+390
-192
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.5
|
||||
// protoc v5.29.0
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.1
|
||||
// source: agent/cvms/cvms.proto
|
||||
|
||||
package cvms
|
||||
@@ -431,6 +431,7 @@ type ClientStreamMessage struct {
|
||||
// *ClientStreamMessage_StopComputationRes
|
||||
// *ClientStreamMessage_AgentStateRes
|
||||
// *ClientStreamMessage_VTPMattestationReport
|
||||
// *ClientStreamMessage_AzureAttestationToken
|
||||
Message isClientStreamMessage_Message `protobuf_oneof:"message"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -527,6 +528,15 @@ func (x *ClientStreamMessage) GetVTPMattestationReport() *AttestationResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ClientStreamMessage) GetAzureAttestationToken() *AzureAttestationToken {
|
||||
if x != nil {
|
||||
if x, ok := x.Message.(*ClientStreamMessage_AzureAttestationToken); ok {
|
||||
return x.AzureAttestationToken
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isClientStreamMessage_Message interface {
|
||||
isClientStreamMessage_Message()
|
||||
}
|
||||
@@ -555,6 +565,10 @@ type ClientStreamMessage_VTPMattestationReport struct {
|
||||
VTPMattestationReport *AttestationResponse `protobuf:"bytes,6,opt,name=vTPMattestationReport,proto3,oneof"`
|
||||
}
|
||||
|
||||
type ClientStreamMessage_AzureAttestationToken struct {
|
||||
AzureAttestationToken *AzureAttestationToken `protobuf:"bytes,7,opt,name=azureAttestationToken,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*ClientStreamMessage_AgentLog) isClientStreamMessage_Message() {}
|
||||
|
||||
func (*ClientStreamMessage_AgentEvent) isClientStreamMessage_Message() {}
|
||||
@@ -567,6 +581,8 @@ func (*ClientStreamMessage_AgentStateRes) isClientStreamMessage_Message() {}
|
||||
|
||||
func (*ClientStreamMessage_VTPMattestationReport) isClientStreamMessage_Message() {}
|
||||
|
||||
func (*ClientStreamMessage_AzureAttestationToken) isClientStreamMessage_Message() {}
|
||||
|
||||
type ServerStreamMessage struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Types that are valid to be assigned to Message:
|
||||
@@ -942,6 +958,9 @@ type Dataset struct {
|
||||
Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length.
|
||||
UserKey []byte `protobuf:"bytes,2,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
Filename string `protobuf:"bytes,3,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
Source *Source `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted dataset
|
||||
Decompress bool `protobuf:"varint,5,opt,name=decompress,proto3" json:"decompress,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -997,10 +1016,35 @@ func (x *Dataset) GetFilename() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Dataset) GetSource() *Source {
|
||||
if x != nil {
|
||||
return x.Source
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Dataset) GetDecompress() bool {
|
||||
if x != nil {
|
||||
return x.Decompress
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *Dataset) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Algorithm struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length.
|
||||
UserKey []byte `protobuf:"bytes,2,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
Source *Source `protobuf:"bytes,3,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted algorithm
|
||||
AlgoType string `protobuf:"bytes,4,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `protobuf:"bytes,5,rep,name=algo_args,json=algoArgs,proto3" json:"algo_args,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1049,6 +1093,154 @@ func (x *Algorithm) GetUserKey() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetSource() *Source {
|
||||
if x != nil {
|
||||
return x.Source
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetAlgoType() string {
|
||||
if x != nil {
|
||||
return x.AlgoType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetAlgoArgs() []string {
|
||||
if x != nil {
|
||||
return x.AlgoArgs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` // Type of source: "oci-image" (only OCI images supported for CoCo)
|
||||
Url string `protobuf:"bytes,2,opt,name=url,proto3" json:"url,omitempty"` // URL of the OCI image (e.g., docker://registry/repo:tag)
|
||||
KbsResourcePath string `protobuf:"bytes,3,opt,name=kbs_resource_path,json=kbsResourcePath,proto3" json:"kbs_resource_path,omitempty"` // Path to decryption key in KBS (e.g., "default/key/my-key")
|
||||
Encrypted bool `protobuf:"varint,4,opt,name=encrypted,proto3" json:"encrypted,omitempty"` // Whether the resource is encrypted (requires KBS)
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Source) Reset() {
|
||||
*x = Source{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Source) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Source) ProtoMessage() {}
|
||||
|
||||
func (x *Source) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Source.ProtoReflect.Descriptor instead.
|
||||
func (*Source) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{15}
|
||||
}
|
||||
|
||||
func (x *Source) GetType() string {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetUrl() string {
|
||||
if x != nil {
|
||||
return x.Url
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetKbsResourcePath() string {
|
||||
if x != nil {
|
||||
return x.KbsResourcePath
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Source) GetEncrypted() bool {
|
||||
if x != nil {
|
||||
return x.Encrypted
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type KBSConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` // KBS endpoint URL (e.g., "https://kbs.example.com")
|
||||
Enabled bool `protobuf:"varint,2,opt,name=enabled,proto3" json:"enabled,omitempty"` // Whether to use KBS for key retrieval
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *KBSConfig) Reset() {
|
||||
*x = KBSConfig{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *KBSConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*KBSConfig) ProtoMessage() {}
|
||||
|
||||
func (x *KBSConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use KBSConfig.ProtoReflect.Descriptor instead.
|
||||
func (*KBSConfig) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{16}
|
||||
}
|
||||
|
||||
func (x *KBSConfig) GetUrl() string {
|
||||
if x != nil {
|
||||
return x.Url
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *KBSConfig) GetEnabled() bool {
|
||||
if x != nil {
|
||||
return x.Enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Port string `protobuf:"bytes,1,opt,name=port,proto3" json:"port,omitempty"`
|
||||
@@ -1064,7 +1256,7 @@ type AgentConfig struct {
|
||||
|
||||
func (x *AgentConfig) Reset() {
|
||||
*x = AgentConfig{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1076,7 +1268,7 @@ func (x *AgentConfig) String() string {
|
||||
func (*AgentConfig) ProtoMessage() {}
|
||||
|
||||
func (x *AgentConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[15]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[17]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1089,7 +1281,7 @@ func (x *AgentConfig) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use AgentConfig.ProtoReflect.Descriptor instead.
|
||||
func (*AgentConfig) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{15}
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{17}
|
||||
}
|
||||
|
||||
func (x *AgentConfig) GetPort() string {
|
||||
@@ -1151,7 +1343,7 @@ type AttestationResponse struct {
|
||||
|
||||
func (x *AttestationResponse) Reset() {
|
||||
*x = AttestationResponse{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[18]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1163,7 +1355,7 @@ func (x *AttestationResponse) String() string {
|
||||
func (*AttestationResponse) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[16]
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[18]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1176,7 +1368,7 @@ func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use AttestationResponse.ProtoReflect.Descriptor instead.
|
||||
func (*AttestationResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{16}
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{18}
|
||||
}
|
||||
|
||||
func (x *AttestationResponse) GetFile() []byte {
|
||||
@@ -1193,168 +1385,165 @@ func (x *AttestationResponse) GetCertSerialNumber() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type AzureAttestationToken struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
CertSerialNumber string `protobuf:"bytes,2,opt,name=certSerialNumber,proto3" json:"certSerialNumber,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AzureAttestationToken) Reset() {
|
||||
*x = AzureAttestationToken{}
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[19]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AzureAttestationToken) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*AzureAttestationToken) ProtoMessage() {}
|
||||
|
||||
func (x *AzureAttestationToken) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_cvms_cvms_proto_msgTypes[19]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use AzureAttestationToken.ProtoReflect.Descriptor instead.
|
||||
func (*AzureAttestationToken) Descriptor() ([]byte, []int) {
|
||||
return file_agent_cvms_cvms_proto_rawDescGZIP(), []int{19}
|
||||
}
|
||||
|
||||
func (x *AzureAttestationToken) GetFile() []byte {
|
||||
if x != nil {
|
||||
return x.File
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *AzureAttestationToken) GetCertSerialNumber() string {
|
||||
if x != nil {
|
||||
return x.CertSerialNumber
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_agent_cvms_cvms_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_agent_cvms_cvms_proto_rawDesc = string([]byte{
|
||||
0x0a, 0x15, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x63, 0x76, 0x6d, 0x73, 0x2f, 0x63, 0x76, 0x6d,
|
||||
0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x76, 0x6d, 0x73, 0x1a, 0x1f, 0x67,
|
||||
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74,
|
||||
0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x1f,
|
||||
0x0a, 0x0d, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x12,
|
||||
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x22,
|
||||
0x35, 0x0a, 0x0d, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73,
|
||||
0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64,
|
||||
0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x22, 0x38, 0x0a, 0x0f, 0x53, 0x74, 0x6f, 0x70, 0x43, 0x6f,
|
||||
0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6f, 0x6d,
|
||||
0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64,
|
||||
0x22, 0x5a, 0x0a, 0x17, 0x53, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74,
|
||||
0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
|
||||
0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x4a, 0x0a, 0x0b,
|
||||
0x52, 0x75, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
|
||||
0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xde, 0x01, 0x0a, 0x0a, 0x41, 0x67, 0x65,
|
||||
0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x76, 0x65, 0x6e, 0x74,
|
||||
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x65, 0x76, 0x65,
|
||||
0x6e, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
||||
0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
|
||||
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65,
|
||||
0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70,
|
||||
0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
|
||||
0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69,
|
||||
0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c,
|
||||
0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x18,
|
||||
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f,
|
||||
0x72, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x9b, 0x01, 0x0a, 0x08, 0x41, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x4c, 0x6f, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
|
||||
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c,
|
||||
0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x38, 0x0a,
|
||||
0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b,
|
||||
0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
|
||||
0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69,
|
||||
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x22, 0x93, 0x03, 0x0a, 0x13, 0x43, 0x6c, 0x69, 0x65,
|
||||
0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
|
||||
0x2d, 0x0a, 0x09, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x48, 0x00, 0x52, 0x08, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x4c, 0x6f, 0x67, 0x12, 0x33,
|
||||
0x0a, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74,
|
||||
0x45, 0x76, 0x65, 0x6e, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76,
|
||||
0x65, 0x6e, 0x74, 0x12, 0x2c, 0x0a, 0x07, 0x72, 0x75, 0x6e, 0x5f, 0x72, 0x65, 0x73, 0x18, 0x03,
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x52, 0x75, 0x6e, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x06, 0x72, 0x75, 0x6e, 0x52, 0x65,
|
||||
0x73, 0x12, 0x4f, 0x0a, 0x12, 0x73, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61,
|
||||
0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e,
|
||||
0x63, 0x76, 0x6d, 0x73, 0x2e, 0x53, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61,
|
||||
0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x12,
|
||||
0x73, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52,
|
||||
0x65, 0x73, 0x12, 0x3b, 0x0a, 0x0d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65,
|
||||
0x52, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x63, 0x76, 0x6d, 0x73,
|
||||
0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x48, 0x00,
|
||||
0x52, 0x0d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x12,
|
||||
0x51, 0x0a, 0x15, 0x76, 0x54, 0x50, 0x4d, 0x61, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19,
|
||||
0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x15, 0x76, 0x54, 0x50,
|
||||
0x4d, 0x61, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x70, 0x6f,
|
||||
0x72, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, 0x02,
|
||||
0x0a, 0x13, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65,
|
||||
0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x0c, 0x72, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x43,
|
||||
0x68, 0x75, 0x6e, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x63, 0x76,
|
||||
0x6d, 0x73, 0x2e, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x48,
|
||||
0x00, 0x52, 0x0c, 0x72, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x12,
|
||||
0x31, 0x0a, 0x06, 0x72, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
|
||||
0x17, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x48, 0x00, 0x52, 0x06, 0x72, 0x75, 0x6e, 0x52,
|
||||
0x65, 0x71, 0x12, 0x41, 0x0a, 0x0f, 0x73, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x63, 0x76,
|
||||
0x6d, 0x73, 0x2e, 0x53, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x48, 0x00, 0x52, 0x0f, 0x73, 0x74, 0x6f, 0x70, 0x43, 0x6f, 0x6d, 0x70, 0x75, 0x74,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3b, 0x0a, 0x0d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74,
|
||||
0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x63,
|
||||
0x76, 0x6d, 0x73, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65,
|
||||
0x71, 0x48, 0x00, 0x52, 0x0d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52,
|
||||
0x65, 0x71, 0x12, 0x3b, 0x0a, 0x0d, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74,
|
||||
0x52, 0x65, 0x71, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x63, 0x76, 0x6d, 0x73,
|
||||
0x2e, 0x44, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x71, 0x48, 0x00,
|
||||
0x52, 0x0d, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x71, 0x42,
|
||||
0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1f, 0x0a, 0x0d, 0x44, 0x69,
|
||||
0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x52, 0x65, 0x71, 0x12, 0x0e, 0x0a, 0x02, 0x69,
|
||||
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x22, 0x4b, 0x0a, 0x0c, 0x52,
|
||||
0x75, 0x6e, 0x52, 0x65, 0x71, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x64,
|
||||
0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12,
|
||||
0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12,
|
||||
0x17, 0x0a, 0x07, 0x69, 0x73, 0x5f, 0x6c, 0x61, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08,
|
||||
0x52, 0x06, 0x69, 0x73, 0x4c, 0x61, 0x73, 0x74, 0x22, 0xaa, 0x02, 0x0a, 0x11, 0x43, 0x6f, 0x6d,
|
||||
0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x75, 0x6e, 0x52, 0x65, 0x71, 0x12, 0x0e,
|
||||
0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12,
|
||||
0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
|
||||
0x6d, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70,
|
||||
0x74, 0x69, 0x6f, 0x6e, 0x12, 0x29, 0x0a, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73,
|
||||
0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x44, 0x61,
|
||||
0x74, 0x61, 0x73, 0x65, 0x74, 0x52, 0x08, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x73, 0x12,
|
||||
0x2d, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x05, 0x20, 0x01,
|
||||
0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69,
|
||||
0x74, 0x68, 0x6d, 0x52, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x3f,
|
||||
0x0a, 0x10, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65,
|
||||
0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e,
|
||||
0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x52, 0x0f,
|
||||
0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x73, 0x12,
|
||||
0x34, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18,
|
||||
0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x41, 0x67, 0x65,
|
||||
0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x2a, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x43,
|
||||
0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b,
|
||||
0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65,
|
||||
0x79, 0x22, 0x53, 0x0a, 0x07, 0x44, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x68, 0x61, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x61, 0x73, 0x68,
|
||||
0x12, 0x18, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x66, 0x69,
|
||||
0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69,
|
||||
0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x39, 0x0a, 0x09, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69,
|
||||
0x74, 0x68, 0x6d, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x61, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x0c, 0x52, 0x04, 0x68, 0x61, 0x73, 0x68, 0x12, 0x18, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b,
|
||||
0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x75, 0x73, 0x65, 0x72, 0x4b, 0x65,
|
||||
0x79, 0x22, 0xe5, 0x01, 0x0a, 0x0b, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69,
|
||||
0x67, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x66, 0x69,
|
||||
0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x65, 0x72, 0x74, 0x46, 0x69,
|
||||
0x6c, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x6b, 0x65, 0x79, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x03,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6b, 0x65, 0x79, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x24, 0x0a,
|
||||
0x0e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x61, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18,
|
||||
0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, 0x46,
|
||||
0x69, 0x6c, 0x65, 0x12, 0x24, 0x0a, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, 0x61,
|
||||
0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x43, 0x61, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6c, 0x6f, 0x67,
|
||||
0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6c, 0x6f,
|
||||
0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x74, 0x74, 0x65, 0x73, 0x74,
|
||||
0x65, 0x64, 0x5f, 0x74, 0x6c, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x74,
|
||||
0x74, 0x65, 0x73, 0x74, 0x65, 0x64, 0x54, 0x6c, 0x73, 0x22, 0x55, 0x0a, 0x13, 0x41, 0x74, 0x74,
|
||||
0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04,
|
||||
0x66, 0x69, 0x6c, 0x65, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x65, 0x72, 0x74, 0x53, 0x65, 0x72, 0x69,
|
||||
0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10,
|
||||
0x63, 0x65, 0x72, 0x74, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72,
|
||||
0x32, 0x50, 0x0a, 0x07, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x07, 0x50,
|
||||
0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x43, 0x6c,
|
||||
0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x1a, 0x19, 0x2e, 0x63, 0x76, 0x6d, 0x73, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53,
|
||||
0x74, 0x72, 0x65, 0x61, 0x6d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01,
|
||||
0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2e, 0x2f, 0x63, 0x76, 0x6d, 0x73, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
})
|
||||
const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x15agent/cvms/cvms.proto\x12\x04cvms\x1a\x1fgoogle/protobuf/timestamp.proto\"\x1f\n" +
|
||||
"\rAgentStateReq\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\"5\n" +
|
||||
"\rAgentStateRes\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" +
|
||||
"\x05state\x18\x02 \x01(\tR\x05state\"8\n" +
|
||||
"\x0fStopComputation\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\"Z\n" +
|
||||
"\x17StopComputationResponse\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x18\n" +
|
||||
"\amessage\x18\x02 \x01(\tR\amessage\"J\n" +
|
||||
"\vRunResponse\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05error\x18\x02 \x01(\tR\x05error\"\xde\x01\n" +
|
||||
"\n" +
|
||||
"AgentEvent\x12\x1d\n" +
|
||||
"\n" +
|
||||
"event_type\x18\x01 \x01(\tR\teventType\x128\n" +
|
||||
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" +
|
||||
"\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" +
|
||||
"\n" +
|
||||
"originator\x18\x05 \x01(\tR\n" +
|
||||
"originator\x12\x16\n" +
|
||||
"\x06status\x18\x06 \x01(\tR\x06status\"\x9b\x01\n" +
|
||||
"\bAgentLog\x12\x18\n" +
|
||||
"\amessage\x18\x01 \x01(\tR\amessage\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05level\x18\x03 \x01(\tR\x05level\x128\n" +
|
||||
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\xe8\x03\n" +
|
||||
"\x13ClientStreamMessage\x12-\n" +
|
||||
"\tagent_log\x18\x01 \x01(\v2\x0e.cvms.AgentLogH\x00R\bagentLog\x123\n" +
|
||||
"\vagent_event\x18\x02 \x01(\v2\x10.cvms.AgentEventH\x00R\n" +
|
||||
"agentEvent\x12,\n" +
|
||||
"\arun_res\x18\x03 \x01(\v2\x11.cvms.RunResponseH\x00R\x06runRes\x12O\n" +
|
||||
"\x12stopComputationRes\x18\x04 \x01(\v2\x1d.cvms.StopComputationResponseH\x00R\x12stopComputationRes\x12;\n" +
|
||||
"\ragentStateRes\x18\x05 \x01(\v2\x13.cvms.AgentStateResH\x00R\ragentStateRes\x12Q\n" +
|
||||
"\x15vTPMattestationReport\x18\x06 \x01(\v2\x19.cvms.AttestationResponseH\x00R\x15vTPMattestationReport\x12S\n" +
|
||||
"\x15azureAttestationToken\x18\a \x01(\v2\x1b.cvms.azureAttestationTokenH\x00R\x15azureAttestationTokenB\t\n" +
|
||||
"\amessage\"\xca\x02\n" +
|
||||
"\x13ServerStreamMessage\x128\n" +
|
||||
"\frunReqChunks\x18\x01 \x01(\v2\x12.cvms.RunReqChunksH\x00R\frunReqChunks\x121\n" +
|
||||
"\x06runReq\x18\x02 \x01(\v2\x17.cvms.ComputationRunReqH\x00R\x06runReq\x12A\n" +
|
||||
"\x0fstopComputation\x18\x03 \x01(\v2\x15.cvms.StopComputationH\x00R\x0fstopComputation\x12;\n" +
|
||||
"\ragentStateReq\x18\x04 \x01(\v2\x13.cvms.AgentStateReqH\x00R\ragentStateReq\x12;\n" +
|
||||
"\rdisconnectReq\x18\x05 \x01(\v2\x13.cvms.DisconnectReqH\x00R\rdisconnectReqB\t\n" +
|
||||
"\amessage\"\x1f\n" +
|
||||
"\rDisconnectReq\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\"K\n" +
|
||||
"\fRunReqChunks\x12\x12\n" +
|
||||
"\x04data\x18\x01 \x01(\fR\x04data\x12\x0e\n" +
|
||||
"\x02id\x18\x02 \x01(\tR\x02id\x12\x17\n" +
|
||||
"\ais_last\x18\x03 \x01(\bR\x06isLast\"\xaa\x02\n" +
|
||||
"\x11ComputationRunReq\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
|
||||
"\x04name\x18\x02 \x01(\tR\x04name\x12 \n" +
|
||||
"\vdescription\x18\x03 \x01(\tR\vdescription\x12)\n" +
|
||||
"\bdatasets\x18\x04 \x03(\v2\r.cvms.DatasetR\bdatasets\x12-\n" +
|
||||
"\talgorithm\x18\x05 \x01(\v2\x0f.cvms.AlgorithmR\talgorithm\x12?\n" +
|
||||
"\x10result_consumers\x18\x06 \x03(\v2\x14.cvms.ResultConsumerR\x0fresultConsumers\x124\n" +
|
||||
"\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\"*\n" +
|
||||
"\x0eResultConsumer\x12\x18\n" +
|
||||
"\auserKey\x18\x01 \x01(\fR\auserKey\"\xbc\x01\n" +
|
||||
"\aDataset\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12\x1a\n" +
|
||||
"\bfilename\x18\x03 \x01(\tR\bfilename\x12$\n" +
|
||||
"\x06source\x18\x04 \x01(\v2\f.cvms.SourceR\x06source\x12\x1e\n" +
|
||||
"\n" +
|
||||
"decompress\x18\x05 \x01(\bR\n" +
|
||||
"decompress\x12!\n" +
|
||||
"\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"\xbc\x01\n" +
|
||||
"\tAlgorithm\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12$\n" +
|
||||
"\x06source\x18\x03 \x01(\v2\f.cvms.SourceR\x06source\x12\x1b\n" +
|
||||
"\talgo_type\x18\x04 \x01(\tR\balgoType\x12\x1b\n" +
|
||||
"\talgo_args\x18\x05 \x03(\tR\balgoArgs\x12!\n" +
|
||||
"\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"x\n" +
|
||||
"\x06Source\x12\x12\n" +
|
||||
"\x04type\x18\x01 \x01(\tR\x04type\x12\x10\n" +
|
||||
"\x03url\x18\x02 \x01(\tR\x03url\x12*\n" +
|
||||
"\x11kbs_resource_path\x18\x03 \x01(\tR\x0fkbsResourcePath\x12\x1c\n" +
|
||||
"\tencrypted\x18\x04 \x01(\bR\tencrypted\"7\n" +
|
||||
"\tKBSConfig\x12\x10\n" +
|
||||
"\x03url\x18\x01 \x01(\tR\x03url\x12\x18\n" +
|
||||
"\aenabled\x18\x02 \x01(\bR\aenabled\"\xe5\x01\n" +
|
||||
"\vAgentConfig\x12\x12\n" +
|
||||
"\x04port\x18\x01 \x01(\tR\x04port\x12\x1b\n" +
|
||||
"\tcert_file\x18\x02 \x01(\tR\bcertFile\x12\x19\n" +
|
||||
"\bkey_file\x18\x03 \x01(\tR\akeyFile\x12$\n" +
|
||||
"\x0eclient_ca_file\x18\x04 \x01(\tR\fclientCaFile\x12$\n" +
|
||||
"\x0eserver_ca_file\x18\x05 \x01(\tR\fserverCaFile\x12\x1b\n" +
|
||||
"\tlog_level\x18\x06 \x01(\tR\blogLevel\x12!\n" +
|
||||
"\fattested_tls\x18\a \x01(\bR\vattestedTls\"U\n" +
|
||||
"\x13AttestationResponse\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file\x12*\n" +
|
||||
"\x10certSerialNumber\x18\x02 \x01(\tR\x10certSerialNumber\"W\n" +
|
||||
"\x15azureAttestationToken\x12\x12\n" +
|
||||
"\x04file\x18\x01 \x01(\fR\x04file\x12*\n" +
|
||||
"\x10certSerialNumber\x18\x02 \x01(\tR\x10certSerialNumber2P\n" +
|
||||
"\aService\x12E\n" +
|
||||
"\aProcess\x12\x19.cvms.ClientStreamMessage\x1a\x19.cvms.ServerStreamMessage\"\x00(\x010\x01B\bZ\x06./cvmsb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_cvms_cvms_proto_rawDescOnce sync.Once
|
||||
@@ -1368,7 +1557,7 @@ func file_agent_cvms_cvms_proto_rawDescGZIP() []byte {
|
||||
return file_agent_cvms_cvms_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_cvms_cvms_proto_msgTypes = make([]protoimpl.MessageInfo, 17)
|
||||
var file_agent_cvms_cvms_proto_msgTypes = make([]protoimpl.MessageInfo, 20)
|
||||
var file_agent_cvms_cvms_proto_goTypes = []any{
|
||||
(*AgentStateReq)(nil), // 0: cvms.AgentStateReq
|
||||
(*AgentStateRes)(nil), // 1: cvms.AgentStateRes
|
||||
@@ -1385,35 +1574,43 @@ var file_agent_cvms_cvms_proto_goTypes = []any{
|
||||
(*ResultConsumer)(nil), // 12: cvms.ResultConsumer
|
||||
(*Dataset)(nil), // 13: cvms.Dataset
|
||||
(*Algorithm)(nil), // 14: cvms.Algorithm
|
||||
(*AgentConfig)(nil), // 15: cvms.AgentConfig
|
||||
(*AttestationResponse)(nil), // 16: cvms.AttestationResponse
|
||||
(*timestamppb.Timestamp)(nil), // 17: google.protobuf.Timestamp
|
||||
(*Source)(nil), // 15: cvms.Source
|
||||
(*KBSConfig)(nil), // 16: cvms.KBSConfig
|
||||
(*AgentConfig)(nil), // 17: cvms.AgentConfig
|
||||
(*AttestationResponse)(nil), // 18: cvms.AttestationResponse
|
||||
(*AzureAttestationToken)(nil), // 19: cvms.azureAttestationToken
|
||||
(*timestamppb.Timestamp)(nil), // 20: google.protobuf.Timestamp
|
||||
}
|
||||
var file_agent_cvms_cvms_proto_depIdxs = []int32{
|
||||
17, // 0: cvms.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
17, // 1: cvms.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
|
||||
20, // 0: cvms.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
20, // 1: cvms.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
|
||||
6, // 2: cvms.ClientStreamMessage.agent_log:type_name -> cvms.AgentLog
|
||||
5, // 3: cvms.ClientStreamMessage.agent_event:type_name -> cvms.AgentEvent
|
||||
4, // 4: cvms.ClientStreamMessage.run_res:type_name -> cvms.RunResponse
|
||||
3, // 5: cvms.ClientStreamMessage.stopComputationRes:type_name -> cvms.StopComputationResponse
|
||||
1, // 6: cvms.ClientStreamMessage.agentStateRes:type_name -> cvms.AgentStateRes
|
||||
16, // 7: cvms.ClientStreamMessage.vTPMattestationReport:type_name -> cvms.AttestationResponse
|
||||
10, // 8: cvms.ServerStreamMessage.runReqChunks:type_name -> cvms.RunReqChunks
|
||||
11, // 9: cvms.ServerStreamMessage.runReq:type_name -> cvms.ComputationRunReq
|
||||
2, // 10: cvms.ServerStreamMessage.stopComputation:type_name -> cvms.StopComputation
|
||||
0, // 11: cvms.ServerStreamMessage.agentStateReq:type_name -> cvms.AgentStateReq
|
||||
9, // 12: cvms.ServerStreamMessage.disconnectReq:type_name -> cvms.DisconnectReq
|
||||
13, // 13: cvms.ComputationRunReq.datasets:type_name -> cvms.Dataset
|
||||
14, // 14: cvms.ComputationRunReq.algorithm:type_name -> cvms.Algorithm
|
||||
12, // 15: cvms.ComputationRunReq.result_consumers:type_name -> cvms.ResultConsumer
|
||||
15, // 16: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig
|
||||
7, // 17: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 18: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
18, // [18:19] is the sub-list for method output_type
|
||||
17, // [17:18] is the sub-list for method input_type
|
||||
17, // [17:17] is the sub-list for extension type_name
|
||||
17, // [17:17] is the sub-list for extension extendee
|
||||
0, // [0:17] is the sub-list for field type_name
|
||||
18, // 7: cvms.ClientStreamMessage.vTPMattestationReport:type_name -> cvms.AttestationResponse
|
||||
19, // 8: cvms.ClientStreamMessage.azureAttestationToken:type_name -> cvms.azureAttestationToken
|
||||
10, // 9: cvms.ServerStreamMessage.runReqChunks:type_name -> cvms.RunReqChunks
|
||||
11, // 10: cvms.ServerStreamMessage.runReq:type_name -> cvms.ComputationRunReq
|
||||
2, // 11: cvms.ServerStreamMessage.stopComputation:type_name -> cvms.StopComputation
|
||||
0, // 12: cvms.ServerStreamMessage.agentStateReq:type_name -> cvms.AgentStateReq
|
||||
9, // 13: cvms.ServerStreamMessage.disconnectReq:type_name -> cvms.DisconnectReq
|
||||
13, // 14: cvms.ComputationRunReq.datasets:type_name -> cvms.Dataset
|
||||
14, // 15: cvms.ComputationRunReq.algorithm:type_name -> cvms.Algorithm
|
||||
12, // 16: cvms.ComputationRunReq.result_consumers:type_name -> cvms.ResultConsumer
|
||||
17, // 17: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig
|
||||
15, // 18: cvms.Dataset.source:type_name -> cvms.Source
|
||||
16, // 19: cvms.Dataset.kbs:type_name -> cvms.KBSConfig
|
||||
15, // 20: cvms.Algorithm.source:type_name -> cvms.Source
|
||||
16, // 21: cvms.Algorithm.kbs:type_name -> cvms.KBSConfig
|
||||
7, // 22: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 23: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
23, // [23:24] is the sub-list for method output_type
|
||||
22, // [22:23] is the sub-list for method input_type
|
||||
22, // [22:22] is the sub-list for extension type_name
|
||||
22, // [22:22] is the sub-list for extension extendee
|
||||
0, // [0:22] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_cvms_cvms_proto_init() }
|
||||
@@ -1428,6 +1625,7 @@ func file_agent_cvms_cvms_proto_init() {
|
||||
(*ClientStreamMessage_StopComputationRes)(nil),
|
||||
(*ClientStreamMessage_AgentStateRes)(nil),
|
||||
(*ClientStreamMessage_VTPMattestationReport)(nil),
|
||||
(*ClientStreamMessage_AzureAttestationToken)(nil),
|
||||
}
|
||||
file_agent_cvms_cvms_proto_msgTypes[8].OneofWrappers = []any{
|
||||
(*ServerStreamMessage_RunReqChunks)(nil),
|
||||
@@ -1442,7 +1640,7 @@ func file_agent_cvms_cvms_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_cvms_cvms_proto_rawDesc), len(file_agent_cvms_cvms_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 17,
|
||||
NumMessages: 20,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -60,6 +60,7 @@ message ClientStreamMessage {
|
||||
StopComputationResponse stopComputationRes = 4;
|
||||
AgentStateRes agentStateRes = 5;
|
||||
AttestationResponse vTPMattestationReport = 6;
|
||||
azureAttestationToken azureAttestationToken = 7;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,11 +102,30 @@ message Dataset {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
string filename = 3;
|
||||
Source source = 4; // Optional remote source for encrypted dataset
|
||||
bool decompress = 5;
|
||||
KBSConfig kbs = 6; // Optional KBS configuration override
|
||||
}
|
||||
|
||||
message Algorithm {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
Source source = 3; // Optional remote source for encrypted algorithm
|
||||
string algo_type = 4;
|
||||
repeated string algo_args = 5;
|
||||
KBSConfig kbs = 6; // Optional KBS configuration override
|
||||
}
|
||||
|
||||
message Source {
|
||||
string type = 1; // Type of source: "oci-image", "s3", "gcs", "https", "http"
|
||||
string url = 2; // URL of the resource (e.g., docker://registry/repo:tag, s3://bucket/key, https://host/path)
|
||||
string kbs_resource_path = 3; // Path to decryption key in KBS (e.g., "default/key/my-key")
|
||||
bool encrypted = 4; // Whether the resource is encrypted (requires KBS)
|
||||
}
|
||||
|
||||
message KBSConfig {
|
||||
string url = 1; // KBS endpoint URL (e.g., "https://kbs.example.com")
|
||||
bool enabled = 2; // Whether to use KBS for key retrieval
|
||||
}
|
||||
|
||||
message AgentConfig {
|
||||
@@ -122,3 +142,8 @@ message AttestationResponse {
|
||||
bytes file = 1;
|
||||
string certSerialNumber = 2;
|
||||
}
|
||||
|
||||
message azureAttestationToken {
|
||||
bytes file = 1;
|
||||
string certSerialNumber = 2;
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.29.0
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc v6.33.1
|
||||
// source: agent/cvms/cvms.proto
|
||||
|
||||
package cvms
|
||||
@@ -69,7 +69,7 @@ type ServiceServer interface {
|
||||
type UnimplementedServiceServer struct{}
|
||||
|
||||
func (UnimplementedServiceServer) Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Process not implemented")
|
||||
return status.Error(codes.Unimplemented, "method Process not implemented")
|
||||
}
|
||||
func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {}
|
||||
func (UnimplementedServiceServer) testEmbeddedByValue() {}
|
||||
@@ -82,7 +82,7 @@ type UnsafeServiceServer interface {
|
||||
}
|
||||
|
||||
func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedServiceServer was
|
||||
// If the following call panics, it indicates UnimplementedServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
|
||||
+62
-44
@@ -4,23 +4,26 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "agent"
|
||||
defSvcGRPCPort = "7002"
|
||||
svcName = "agent"
|
||||
defSvcGRPCSocket = "/run/cocos/agent.sock"
|
||||
)
|
||||
|
||||
type AgentServer interface {
|
||||
@@ -29,63 +32,76 @@ type AgentServer interface {
|
||||
}
|
||||
|
||||
type agentServer struct {
|
||||
gs server.Server
|
||||
mu sync.Mutex
|
||||
gs *grpc.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
qp client.LeveledQuoteProvider
|
||||
caUrl string
|
||||
cvmId string
|
||||
}
|
||||
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string, qp client.LeveledQuoteProvider, caUrl string, cvmId string) AgentServer {
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string) AgentServer {
|
||||
return &agentServer{
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
qp: qp,
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
}
|
||||
}
|
||||
|
||||
func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
if cfg.Port == "" {
|
||||
cfg.Port = defSvcGRPCPort
|
||||
}
|
||||
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: as.host,
|
||||
Port: cfg.Port,
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
ServerCAFile: cfg.ServerCAFile,
|
||||
ClientCAFile: cfg.ClientCAFile,
|
||||
},
|
||||
},
|
||||
AttestedTLS: cfg.AttestedTls,
|
||||
}
|
||||
|
||||
registerAgentServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc))
|
||||
}
|
||||
|
||||
authSvc, err := auth.New(cmp)
|
||||
if err != nil {
|
||||
as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
grpcServerOptions := []grpc.ServerOption{
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
}
|
||||
|
||||
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, as.qp, authSvc, as.caUrl, as.cvmId)
|
||||
// Add authentication interceptors
|
||||
unary, stream := agentgrpc.NewAuthInterceptor(authSvc)
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
|
||||
|
||||
// Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials()))
|
||||
|
||||
as.mu.Lock()
|
||||
as.gs = grpc.NewServer(grpcServerOptions...)
|
||||
gs := as.gs
|
||||
as.mu.Unlock()
|
||||
|
||||
reflection.Register(gs)
|
||||
agent.RegisterAgentServiceServer(gs, agentgrpc.NewServer(as.svc))
|
||||
|
||||
healthServer := health.NewServer()
|
||||
healthServer.SetServingStatus("agent", grpc_health_v1.HealthCheckResponse_SERVING)
|
||||
grpc_health_v1.RegisterHealthServer(gs, healthServer)
|
||||
|
||||
socketPath := as.host
|
||||
if socketPath == "" || socketPath == "0.0.0.0" {
|
||||
socketPath = defSvcGRPCSocket
|
||||
}
|
||||
|
||||
var listener net.Listener
|
||||
if socketPath[0] == '/' || socketPath[0] == '.' {
|
||||
// Remove existing socket file if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
listener, err = net.Listen("unix", socketPath)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", socketPath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
as.logger.Error(fmt.Sprintf("failed to listen on %s: %s", socketPath, err))
|
||||
return err
|
||||
}
|
||||
|
||||
as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath))
|
||||
|
||||
go func() {
|
||||
err := as.gs.Start()
|
||||
if err != nil {
|
||||
err := gs.Serve(listener)
|
||||
if err != nil && err != grpc.ErrServerStopped {
|
||||
as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
@@ -94,8 +110,10 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
}
|
||||
|
||||
func (as *agentServer) Stop() error {
|
||||
if as.gs == nil {
|
||||
return nil
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.gs != nil {
|
||||
as.gs.GracefulStop()
|
||||
}
|
||||
return as.gs.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,518 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
mockSvc := new(mocks.Service)
|
||||
host := "localhost:0"
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.NoError(t, err, "Failed to generate ECDSA key")
|
||||
|
||||
pubkey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
|
||||
assert.NoError(t, err, "Failed to marshal public key")
|
||||
|
||||
return logger, mockSvc, host, pubkey
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
logger, svc, host, _ := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
expected AgentServer
|
||||
}{
|
||||
{
|
||||
name: "valid server creation",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
},
|
||||
{
|
||||
name: "server with empty host",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: "",
|
||||
},
|
||||
{
|
||||
name: "server with empty caUrl",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
},
|
||||
{
|
||||
name: "server with empty cvmId",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(tt.logger, tt.svc, tt.host)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
|
||||
agentSrv, ok := server.(*agentServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.logger, agentSrv.logger)
|
||||
assert.Equal(t, tt.svc, agentSrv.svc)
|
||||
assert.Equal(t, tt.host, agentSrv.host)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_Start(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg agent.AgentConfig
|
||||
cmp agent.Computation
|
||||
setupMocks func(*mocks.Service)
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful start with default port",
|
||||
cfg: agent.AgentConfig{
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-1",
|
||||
Name: "Test Computation",
|
||||
Description: "A test computation",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x01, 0x02, 0x03},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x04, 0x05, 0x06},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "successful start with custom port",
|
||||
cfg: agent.AgentConfig{
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: false,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-2",
|
||||
Name: "Test Computation 2",
|
||||
Description: "Another test computation",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x07, 0x08, 0x09},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x0a, 0x0b, 0x0c},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "start with minimal config",
|
||||
cfg: agent.AgentConfig{
|
||||
AttestedTls: false,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-3",
|
||||
Name: "Minimal Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x0d, 0x0e, 0x0f},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x10, 0x11, 0x12},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMocks(svc)
|
||||
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := server.Start(tt.cfg, tt.cmp)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the port was set correctly
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("Failed to stop server after start: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_Stop(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func(AgentServer) error
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "stop unstarted server",
|
||||
setupServer: func(server AgentServer) error {
|
||||
// Don't start the server
|
||||
return nil
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "stop started server",
|
||||
setupServer: func(server AgentServer) error {
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-stop-computation",
|
||||
Name: "Stop Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x19, 0x1a, 0x1b},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x1c, 0x1d, 0x1e},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
return server.Start(cfg, cmp)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := tt.setupServer(server)
|
||||
if err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
|
||||
// Give the server a moment to start if it was started
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
// Start the server
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-multiple-stop",
|
||||
Name: "Multiple Stop Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x1f, 0x20, 0x21},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x22, 0x23, 0x24},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.Start(cfg, cmp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Give the server a moment to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Stop the server multiple times
|
||||
err1 := server.Stop()
|
||||
err2 := server.Stop()
|
||||
err3 := server.Stop()
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NoError(t, err3)
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-restart",
|
||||
Name: "Restart Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x25, 0x26, 0x27},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x28, 0x29, 0x2a},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start, stop, then start again
|
||||
err := server.Start(cfg, cmp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start again with different config
|
||||
cfg2 := agent.AgentConfig{}
|
||||
cmp2 := agent.Computation{
|
||||
ID: "test-restart-2",
|
||||
Name: "Restart Test 2",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x2b, 0x2c, 0x2d},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x2e, 0x2f, 0x30},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = server.Start(cfg2, cmp2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config agent.AgentConfig
|
||||
cmp agent.Computation
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid config with all fields",
|
||||
config: agent.AgentConfig{
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "valid-config-test",
|
||||
Name: "Valid Config Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x31, 0x32, 0x33},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x34, 0x35, 0x36},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid config with minimal fields",
|
||||
config: agent.AgentConfig{},
|
||||
cmp: agent.Computation{
|
||||
ID: "minimal-config-test",
|
||||
Name: "Minimal Config Test",
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x37, 0x38, 0x39},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x3a, 0x3b, 0x3c},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "config with empty port uses default",
|
||||
config: agent.AgentConfig{},
|
||||
cmp: agent.Computation{
|
||||
ID: "default-port-test",
|
||||
Name: "Default Port Test",
|
||||
Algorithm: &agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey},
|
||||
Datasets: []agent.Dataset{
|
||||
{Hash: [32]byte{0x40, 0x41, 0x42}, UserKey: pubKey},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{UserKey: pubKey},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := server.Start(tt.config, tt.cmp)
|
||||
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify server started successfully
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("Failed to stop server after start: %v", err)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
assert.Equal(t, "agent", svcName)
|
||||
assert.Equal(t, "/run/cocos/agent.sock", defSvcGRPCSocket)
|
||||
}
|
||||
@@ -1,15 +1,31 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
)
|
||||
|
||||
// NewAgentServer creates a new instance of AgentServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentServer(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentServer {
|
||||
mock := &AgentServer{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// AgentServer is an autogenerated mock type for the AgentServer type
|
||||
type AgentServer struct {
|
||||
mock.Mock
|
||||
@@ -23,21 +39,20 @@ func (_m *AgentServer) EXPECT() *AgentServer_Expecter {
|
||||
return &AgentServer_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields: cfg, cmp
|
||||
func (_m *AgentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
ret := _m.Called(cfg, cmp)
|
||||
// Start provides a mock function for the type AgentServer
|
||||
func (_mock *AgentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
ret := _mock.Called(cfg, cmp)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(agent.AgentConfig, agent.Computation) error); ok {
|
||||
r0 = rf(cfg, cmp)
|
||||
if returnFunc, ok := ret.Get(0).(func(agent.AgentConfig, agent.Computation) error); ok {
|
||||
r0 = returnFunc(cfg, cmp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -55,36 +70,46 @@ func (_e *AgentServer_Expecter) Start(cfg interface{}, cmp interface{}) *AgentSe
|
||||
|
||||
func (_c *AgentServer_Start_Call) Run(run func(cfg agent.AgentConfig, cmp agent.Computation)) *AgentServer_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(agent.AgentConfig), args[1].(agent.Computation))
|
||||
var arg0 agent.AgentConfig
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(agent.AgentConfig)
|
||||
}
|
||||
var arg1 agent.Computation
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(agent.Computation)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServer_Start_Call) Return(_a0 error) *AgentServer_Start_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *AgentServer_Start_Call) Return(err error) *AgentServer_Start_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServer_Start_Call) RunAndReturn(run func(agent.AgentConfig, agent.Computation) error) *AgentServer_Start_Call {
|
||||
func (_c *AgentServer_Start_Call) RunAndReturn(run func(cfg agent.AgentConfig, cmp agent.Computation) error) *AgentServer_Start_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with no fields
|
||||
func (_m *AgentServer) Stop() error {
|
||||
ret := _m.Called()
|
||||
// Stop provides a mock function for the type AgentServer
|
||||
func (_mock *AgentServer) Stop() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -105,8 +130,8 @@ func (_c *AgentServer_Stop_Call) Run(run func()) *AgentServer_Stop_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServer_Stop_Call) Return(_a0 error) *AgentServer_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *AgentServer_Stop_Call) Return(err error) *AgentServer_Stop_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -114,17 +139,3 @@ func (_c *AgentServer_Stop_Call) RunAndReturn(run func() error) *AgentServer_Sto
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAgentServer creates a new instance of AgentServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentServer(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentServer {
|
||||
mock := &AgentServer{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+28
-42
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.5
|
||||
// protoc v5.29.0
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.1
|
||||
// source: agent/events/events.proto
|
||||
|
||||
package events
|
||||
@@ -261,46 +261,32 @@ func (*EventsLogs_AgentEvent) isEventsLogs_Message() {}
|
||||
|
||||
var File_agent_events_events_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_agent_events_events_proto_rawDesc = string([]byte{
|
||||
0x0a, 0x19, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x65,
|
||||
0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x65, 0x76, 0x65,
|
||||
0x6e, 0x74, 0x73, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x22, 0xde, 0x01, 0x0a, 0x0a, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76,
|
||||
0x65, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70,
|
||||
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x54, 0x79,
|
||||
0x70, 0x65, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18,
|
||||
0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
|
||||
0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x25, 0x0a, 0x0e,
|
||||
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x03,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04,
|
||||
0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1e, 0x0a,
|
||||
0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x16, 0x0a,
|
||||
0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73,
|
||||
0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x9b, 0x01, 0x0a, 0x08, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x25, 0x0a, 0x0e,
|
||||
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d,
|
||||
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67,
|
||||
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54,
|
||||
0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
||||
0x61, 0x6d, 0x70, 0x22, 0x7f, 0x0a, 0x0a, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x4c, 0x6f, 0x67,
|
||||
0x73, 0x12, 0x2f, 0x0a, 0x09, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x6c, 0x6f, 0x67, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x41, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x4c, 0x6f, 0x67, 0x48, 0x00, 0x52, 0x08, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x12, 0x35, 0x0a, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x65, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
|
||||
0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x61,
|
||||
0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73,
|
||||
0x73, 0x61, 0x67, 0x65, 0x42, 0x0a, 0x5a, 0x08, 0x2e, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
|
||||
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
})
|
||||
const file_agent_events_events_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x19agent/events/events.proto\x12\x06events\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n" +
|
||||
"\n" +
|
||||
"AgentEvent\x12\x1d\n" +
|
||||
"\n" +
|
||||
"event_type\x18\x01 \x01(\tR\teventType\x128\n" +
|
||||
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" +
|
||||
"\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" +
|
||||
"\n" +
|
||||
"originator\x18\x05 \x01(\tR\n" +
|
||||
"originator\x12\x16\n" +
|
||||
"\x06status\x18\x06 \x01(\tR\x06status\"\x9b\x01\n" +
|
||||
"\bAgentLog\x12\x18\n" +
|
||||
"\amessage\x18\x01 \x01(\tR\amessage\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05level\x18\x03 \x01(\tR\x05level\x128\n" +
|
||||
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\x7f\n" +
|
||||
"\n" +
|
||||
"EventsLogs\x12/\n" +
|
||||
"\tagent_log\x18\x01 \x01(\v2\x10.events.AgentLogH\x00R\bagentLog\x125\n" +
|
||||
"\vagent_event\x18\x02 \x01(\v2\x12.events.AgentEventH\x00R\n" +
|
||||
"agentEventB\t\n" +
|
||||
"\amessageB\n" +
|
||||
"Z\b./eventsb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_events_events_proto_rawDescOnce sync.Once
|
||||
|
||||
@@ -1,16 +1,32 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
json "encoding/json"
|
||||
"encoding/json"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
@@ -24,9 +40,10 @@ func (_m *Service) EXPECT() *Service_Expecter {
|
||||
return &Service_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: cmpID, event, status, details
|
||||
func (_m *Service) SendEvent(cmpID string, event string, status string, details json.RawMessage) {
|
||||
_m.Called(cmpID, event, status, details)
|
||||
// SendEvent provides a mock function for the type Service
|
||||
func (_mock *Service) SendEvent(cmpID string, event string, status string, details json.RawMessage) {
|
||||
_mock.Called(cmpID, event, status, details)
|
||||
return
|
||||
}
|
||||
|
||||
// Service_SendEvent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendEvent'
|
||||
@@ -45,7 +62,28 @@ func (_e *Service_Expecter) SendEvent(cmpID interface{}, event interface{}, stat
|
||||
|
||||
func (_c *Service_SendEvent_Call) Run(run func(cmpID string, event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(string), args[2].(string), args[3].(json.RawMessage))
|
||||
var arg0 string
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(string)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 json.RawMessage
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(json.RawMessage)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -55,21 +93,7 @@ func (_c *Service_SendEvent_Call) Return() *Service_SendEvent_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_SendEvent_Call) RunAndReturn(run func(string, string, string, json.RawMessage)) *Service_SendEvent_Call {
|
||||
func (_c *Service_SendEvent_Call) RunAndReturn(run func(cmpID string, event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.1
|
||||
// source: agent/log/log.proto
|
||||
|
||||
package log
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type LogEntry struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *LogEntry) Reset() {
|
||||
*x = LogEntry{}
|
||||
mi := &file_agent_log_log_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *LogEntry) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*LogEntry) ProtoMessage() {}
|
||||
|
||||
func (x *LogEntry) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_log_log_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use LogEntry.ProtoReflect.Descriptor instead.
|
||||
func (*LogEntry) Descriptor() ([]byte, []int) {
|
||||
return file_agent_log_log_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetMessage() string {
|
||||
if x != nil {
|
||||
return x.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetLevel() string {
|
||||
if x != nil {
|
||||
return x.Level
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventEntry struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"` // JSON payload
|
||||
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
|
||||
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *EventEntry) Reset() {
|
||||
*x = EventEntry{}
|
||||
mi := &file_agent_log_log_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *EventEntry) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*EventEntry) ProtoMessage() {}
|
||||
|
||||
func (x *EventEntry) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_log_log_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use EventEntry.ProtoReflect.Descriptor instead.
|
||||
func (*EventEntry) Descriptor() ([]byte, []int) {
|
||||
return file_agent_log_log_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetEventType() string {
|
||||
if x != nil {
|
||||
return x.EventType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetDetails() []byte {
|
||||
if x != nil {
|
||||
return x.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetOriginator() string {
|
||||
if x != nil {
|
||||
return x.Originator
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetStatus() string {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_agent_log_log_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_agent_log_log_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x13agent/log/log.proto\x12\x03log\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x9b\x01\n" +
|
||||
"\bLogEntry\x12\x18\n" +
|
||||
"\amessage\x18\x01 \x01(\tR\amessage\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05level\x18\x03 \x01(\tR\x05level\x128\n" +
|
||||
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\xde\x01\n" +
|
||||
"\n" +
|
||||
"EventEntry\x12\x1d\n" +
|
||||
"\n" +
|
||||
"event_type\x18\x01 \x01(\tR\teventType\x128\n" +
|
||||
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" +
|
||||
"\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" +
|
||||
"\n" +
|
||||
"originator\x18\x05 \x01(\tR\n" +
|
||||
"originator\x12\x16\n" +
|
||||
"\x06status\x18\x06 \x01(\tR\x06status2v\n" +
|
||||
"\fLogCollector\x120\n" +
|
||||
"\aSendLog\x12\r.log.LogEntry\x1a\x16.google.protobuf.Empty\x124\n" +
|
||||
"\tSendEvent\x12\x0f.log.EventEntry\x1a\x16.google.protobuf.EmptyB\aZ\x05./logb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_log_log_proto_rawDescOnce sync.Once
|
||||
file_agent_log_log_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_agent_log_log_proto_rawDescGZIP() []byte {
|
||||
file_agent_log_log_proto_rawDescOnce.Do(func() {
|
||||
file_agent_log_log_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)))
|
||||
})
|
||||
return file_agent_log_log_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_log_log_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_agent_log_log_proto_goTypes = []any{
|
||||
(*LogEntry)(nil), // 0: log.LogEntry
|
||||
(*EventEntry)(nil), // 1: log.EventEntry
|
||||
(*timestamppb.Timestamp)(nil), // 2: google.protobuf.Timestamp
|
||||
(*emptypb.Empty)(nil), // 3: google.protobuf.Empty
|
||||
}
|
||||
var file_agent_log_log_proto_depIdxs = []int32{
|
||||
2, // 0: log.LogEntry.timestamp:type_name -> google.protobuf.Timestamp
|
||||
2, // 1: log.EventEntry.timestamp:type_name -> google.protobuf.Timestamp
|
||||
0, // 2: log.LogCollector.SendLog:input_type -> log.LogEntry
|
||||
1, // 3: log.LogCollector.SendEvent:input_type -> log.EventEntry
|
||||
3, // 4: log.LogCollector.SendLog:output_type -> google.protobuf.Empty
|
||||
3, // 5: log.LogCollector.SendEvent:output_type -> google.protobuf.Empty
|
||||
4, // [4:6] is the sub-list for method output_type
|
||||
2, // [2:4] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_log_log_proto_init() }
|
||||
func file_agent_log_log_proto_init() {
|
||||
if File_agent_log_log_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_agent_log_log_proto_goTypes,
|
||||
DependencyIndexes: file_agent_log_log_proto_depIdxs,
|
||||
MessageInfos: file_agent_log_log_proto_msgTypes,
|
||||
}.Build()
|
||||
File_agent_log_log_proto = out.File
|
||||
file_agent_log_log_proto_goTypes = nil
|
||||
file_agent_log_log_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package log;
|
||||
|
||||
option go_package = "./log";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
service LogCollector {
|
||||
rpc SendLog(LogEntry) returns (google.protobuf.Empty);
|
||||
rpc SendEvent(EventEntry) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
||||
message LogEntry {
|
||||
string message = 1;
|
||||
string computation_id = 2;
|
||||
string level = 3;
|
||||
google.protobuf.Timestamp timestamp = 4;
|
||||
}
|
||||
|
||||
message EventEntry {
|
||||
string event_type = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
string computation_id = 3;
|
||||
bytes details = 4; // JSON payload
|
||||
string originator = 5;
|
||||
string status = 6;
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc v6.33.1
|
||||
// source: agent/log/log.proto
|
||||
|
||||
package log
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
LogCollector_SendLog_FullMethodName = "/log.LogCollector/SendLog"
|
||||
LogCollector_SendEvent_FullMethodName = "/log.LogCollector/SendEvent"
|
||||
)
|
||||
|
||||
// LogCollectorClient is the client API for LogCollector service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type LogCollectorClient interface {
|
||||
SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type logCollectorClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewLogCollectorClient(cc grpc.ClientConnInterface) LogCollectorClient {
|
||||
return &logCollectorClient{cc}
|
||||
}
|
||||
|
||||
func (c *logCollectorClient) SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, LogCollector_SendLog_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *logCollectorClient) SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, LogCollector_SendEvent_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LogCollectorServer is the server API for LogCollector service.
|
||||
// All implementations must embed UnimplementedLogCollectorServer
|
||||
// for forward compatibility.
|
||||
type LogCollectorServer interface {
|
||||
SendLog(context.Context, *LogEntry) (*emptypb.Empty, error)
|
||||
SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error)
|
||||
mustEmbedUnimplementedLogCollectorServer()
|
||||
}
|
||||
|
||||
// UnimplementedLogCollectorServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedLogCollectorServer struct{}
|
||||
|
||||
func (UnimplementedLogCollectorServer) SendLog(context.Context, *LogEntry) (*emptypb.Empty, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method SendLog not implemented")
|
||||
}
|
||||
func (UnimplementedLogCollectorServer) SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method SendEvent not implemented")
|
||||
}
|
||||
func (UnimplementedLogCollectorServer) mustEmbedUnimplementedLogCollectorServer() {}
|
||||
func (UnimplementedLogCollectorServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeLogCollectorServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to LogCollectorServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeLogCollectorServer interface {
|
||||
mustEmbedUnimplementedLogCollectorServer()
|
||||
}
|
||||
|
||||
func RegisterLogCollectorServer(s grpc.ServiceRegistrar, srv LogCollectorServer) {
|
||||
// If the following call panics, it indicates UnimplementedLogCollectorServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&LogCollector_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _LogCollector_SendLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(LogEntry)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LogCollectorServer).SendLog(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: LogCollector_SendLog_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LogCollectorServer).SendLog(ctx, req.(*LogEntry))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _LogCollector_SendEvent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EventEntry)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LogCollectorServer).SendEvent(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: LogCollector_SendEvent_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LogCollectorServer).SendEvent(ctx, req.(*EventEntry))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// LogCollector_ServiceDesc is the grpc.ServiceDesc for LogCollector service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var LogCollector_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "log.LogCollector",
|
||||
HandlerType: (*LogCollectorServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "SendLog",
|
||||
Handler: _LogCollector_SendLog_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SendEvent",
|
||||
Handler: _LogCollector_SendEvent_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "agent/log/log.proto",
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
var _ log.LogCollectorServer = (*LogForwarder)(nil)
|
||||
|
||||
type LogForwarder struct {
|
||||
log.UnimplementedLogCollectorServer
|
||||
cvmsClient cvms.ServiceClient
|
||||
logger *slog.Logger
|
||||
logQueue chan *cvms.ClientStreamMessage
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, cvmsClient cvms.ServiceClient, queue chan *cvms.ClientStreamMessage) *LogForwarder {
|
||||
return &LogForwarder{
|
||||
cvmsClient: cvmsClient,
|
||||
logger: logger,
|
||||
logQueue: queue,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LogForwarder) SendLog(ctx context.Context, req *log.LogEntry) (*emptypb.Empty, error) {
|
||||
s.logQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &cvms.AgentLog{
|
||||
Message: req.Message,
|
||||
ComputationId: req.ComputationId,
|
||||
Level: req.Level,
|
||||
Timestamp: req.Timestamp,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *LogForwarder) SendEvent(ctx context.Context, req *log.EventEntry) (*emptypb.Empty, error) {
|
||||
s.logQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &cvms.AgentEvent{
|
||||
EventType: req.EventType,
|
||||
Timestamp: req.Timestamp,
|
||||
ComputationId: req.ComputationId,
|
||||
Details: req.Details,
|
||||
Originator: req.Originator,
|
||||
Status: req.Status,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// TestNewLogForwarder tests the creation of a new log forwarder.
|
||||
func TestNewLogForwarder(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
require.NotNil(t, lf)
|
||||
assert.NotNil(t, lf.logger)
|
||||
assert.Nil(t, lf.cvmsClient)
|
||||
assert.NotNil(t, lf.logQueue)
|
||||
}
|
||||
|
||||
// TestSendLog tests sending a log entry.
|
||||
func TestSendLog(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "Test log message",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify message was queued
|
||||
select {
|
||||
case msg := <-queue:
|
||||
require.NotNil(t, msg)
|
||||
agentLog := msg.GetAgentLog()
|
||||
assert.NotNil(t, agentLog)
|
||||
assert.Equal(t, "Test log message", agentLog.Message)
|
||||
assert.Equal(t, "computation-1", agentLog.ComputationId)
|
||||
assert.Equal(t, "INFO", agentLog.Level)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendEvent tests sending an event entry.
|
||||
func TestSendEvent(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
details, err := json.Marshal(map[string]string{"key": "value"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &log.EventEntry{
|
||||
EventType: "COMPUTATION_STARTED",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: details,
|
||||
Originator: "runner",
|
||||
Status: "SUCCESS",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify message was queued
|
||||
select {
|
||||
case msg := <-queue:
|
||||
require.NotNil(t, msg)
|
||||
agentEvent := msg.GetAgentEvent()
|
||||
assert.NotNil(t, agentEvent)
|
||||
assert.Equal(t, "COMPUTATION_STARTED", agentEvent.EventType)
|
||||
assert.Equal(t, "computation-1", agentEvent.ComputationId)
|
||||
assert.Equal(t, "runner", agentEvent.Originator)
|
||||
assert.Equal(t, "SUCCESS", agentEvent.Status)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMultipleLogs tests sending multiple log entries.
|
||||
func TestSendMultipleLogs(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
req := &log.LogEntry{
|
||||
Message: "Log message",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, len(queue))
|
||||
}
|
||||
|
||||
// TestSendEventWithVariousTypes tests sending events with different types.
|
||||
func TestSendEventWithVariousTypes(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
eventTypes := []string{"STARTED", "RUNNING", "COMPLETED", "FAILED"}
|
||||
for _, eventType := range eventTypes {
|
||||
details, err := json.Marshal(map[string]string{"type": eventType})
|
||||
require.NoError(t, err)
|
||||
req := &log.EventEntry{
|
||||
EventType: eventType,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: details,
|
||||
Originator: "runner",
|
||||
Status: "OK",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 4, len(queue))
|
||||
}
|
||||
|
||||
// TestSendLogWithEmptyMessage tests sending log with empty message.
|
||||
func TestSendLogWithEmptyMessage(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
select {
|
||||
case msg := <-queue:
|
||||
agentLog := msg.GetAgentLog()
|
||||
assert.Equal(t, "", agentLog.Message)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendEventWithNilDetails tests sending event with nil details.
|
||||
func TestSendEventWithNilDetails(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.EventEntry{
|
||||
EventType: "TEST_EVENT",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: nil,
|
||||
Originator: "test",
|
||||
Status: "OK",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
select {
|
||||
case msg := <-queue:
|
||||
agentEvent := msg.GetAgentEvent()
|
||||
assert.Nil(t, agentEvent.Details)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendLogWithVariousLevels tests sending logs with various severity levels.
|
||||
func TestSendLogWithVariousLevels(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
levels := []string{"DEBUG", "INFO", "WARN", "ERROR"}
|
||||
for _, level := range levels {
|
||||
req := &log.LogEntry{
|
||||
Message: "Test " + level,
|
||||
ComputationId: "computation-1",
|
||||
Level: level,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 4, len(queue))
|
||||
}
|
||||
|
||||
// TestSendLogWithDifferentComputationIds tests sending logs with different computation IDs.
|
||||
func TestSendLogWithDifferentComputationIds(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
req := &log.LogEntry{
|
||||
Message: "Message",
|
||||
ComputationId: "computation-" + string(rune(48+i)),
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(queue))
|
||||
}
|
||||
|
||||
// TestQueueBehavior tests that queue is properly used.
|
||||
func TestQueueBehavior(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 1)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "Test",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, 1, len(queue))
|
||||
}
|
||||
|
||||
// TestConcurrentSendLog tests concurrent log sending.
|
||||
func TestConcurrentSendLog(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
req := &log.LogEntry{
|
||||
Message: "Concurrent log",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
_, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Give goroutines time to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should have received all messages
|
||||
assert.True(t, len(queue) > 0)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
type MockAttestationClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
args := m.Called(ctx, reportData, nonce, attType)
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
args := m.Called(ctx, reportData, nonce, attType)
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
||||
args := m.Called(ctx, nonce)
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAttestationClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Service_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Service) EXPECT() *Service_Expecter {
|
||||
return &Service_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Algo provides a mock function with given fields: ctx, algorithm
|
||||
func (_m *Service) Algo(ctx context.Context, algorithm agent.Algorithm) error {
|
||||
ret := _m.Called(ctx, algorithm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Algo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm) error); ok {
|
||||
r0 = rf(ctx, algorithm)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Algo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Algo'
|
||||
type Service_Algo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Algo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - algorithm agent.Algorithm
|
||||
func (_e *Service_Expecter) Algo(ctx interface{}, algorithm interface{}) *Service_Algo_Call {
|
||||
return &Service_Algo_Call{Call: _e.mock.On("Algo", ctx, algorithm)}
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) Run(run func(ctx context.Context, algorithm agent.Algorithm)) *Service_Algo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(agent.Algorithm))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) Return(_a0 error) *Service_Algo_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) RunAndReturn(run func(context.Context, agent.Algorithm) error) *Service_Algo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType
|
||||
func (_m *Service) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType) ([]byte, error) {
|
||||
ret := _m.Called(ctx, reportData, nonce, attType)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Attestation")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)); ok {
|
||||
return rf(ctx, reportData, nonce, attType)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, config.AttestationType) []byte); ok {
|
||||
r0 = rf(ctx, reportData, nonce, attType)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, [64]byte, [32]byte, config.AttestationType) error); ok {
|
||||
r1 = rf(ctx, reportData, nonce, attType)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_Attestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Attestation'
|
||||
type Service_Attestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Attestation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - reportData [64]byte
|
||||
// - nonce [32]byte
|
||||
// - attType config.AttestationType
|
||||
func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}) *Service_Attestation_Call {
|
||||
return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType)}
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType config.AttestationType)) *Service_Attestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(config.AttestationType))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) Return(_a0 []byte, _a1 error) *Service_Attestation_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, config.AttestationType) ([]byte, error)) *Service_Attestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Data provides a mock function with given fields: ctx, dataset
|
||||
func (_m *Service) Data(ctx context.Context, dataset agent.Dataset) error {
|
||||
ret := _m.Called(ctx, dataset)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Data")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset) error); ok {
|
||||
r0 = rf(ctx, dataset)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Data_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Data'
|
||||
type Service_Data_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Data is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - dataset agent.Dataset
|
||||
func (_e *Service_Expecter) Data(ctx interface{}, dataset interface{}) *Service_Data_Call {
|
||||
return &Service_Data_Call{Call: _e.mock.On("Data", ctx, dataset)}
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) Run(run func(ctx context.Context, dataset agent.Dataset)) *Service_Data_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(agent.Dataset))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) Return(_a0 error) *Service_Data_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) RunAndReturn(run func(context.Context, agent.Dataset) error) *Service_Data_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// InitComputation provides a mock function with given fields: ctx, cmp
|
||||
func (_m *Service) InitComputation(ctx context.Context, cmp agent.Computation) error {
|
||||
ret := _m.Called(ctx, cmp)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for InitComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Computation) error); ok {
|
||||
r0 = rf(ctx, cmp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_InitComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitComputation'
|
||||
type Service_InitComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InitComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmp agent.Computation
|
||||
func (_e *Service_Expecter) InitComputation(ctx interface{}, cmp interface{}) *Service_InitComputation_Call {
|
||||
return &Service_InitComputation_Call{Call: _e.mock.On("InitComputation", ctx, cmp)}
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Run(run func(ctx context.Context, cmp agent.Computation)) *Service_InitComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(agent.Computation))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Return(_a0 error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) RunAndReturn(run func(context.Context, agent.Computation) error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Result provides a mock function with given fields: ctx
|
||||
func (_m *Service) Result(ctx context.Context) ([]byte, error) {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Result")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) ([]byte, error)); ok {
|
||||
return rf(ctx)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context) []byte); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
|
||||
r1 = rf(ctx)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_Result_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Result'
|
||||
type Service_Result_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Result is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) Result(ctx interface{}) *Service_Result_Call {
|
||||
return &Service_Result_Call{Call: _e.mock.On("Result", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) Run(run func(ctx context.Context)) *Service_Result_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) Return(_a0 []byte, _a1 error) *Service_Result_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) RunAndReturn(run func(context.Context) ([]byte, error)) *Service_Result_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// State provides a mock function with no fields
|
||||
func (_m *Service) State() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for State")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_State_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'State'
|
||||
type Service_State_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// State is a helper method to define mock.On call
|
||||
func (_e *Service_Expecter) State() *Service_State_Call {
|
||||
return &Service_State_Call{Call: _e.mock.On("State")}
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) Run(run func()) *Service_State_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) Return(_a0 string) *Service_State_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) RunAndReturn(run func() string) *Service_State_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// StopComputation provides a mock function with given fields: ctx
|
||||
func (_m *Service) StopComputation(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for StopComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_StopComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopComputation'
|
||||
type Service_StopComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// StopComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) StopComputation(ctx interface{}) *Service_StopComputation_Call {
|
||||
return &Service_StopComputation_Call{Call: _e.mock.On("StopComputation", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Run(run func(ctx context.Context)) *Service_StopComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Return(_a0 error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) RunAndReturn(run func(context.Context) error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,434 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
metadata "google.golang.org/grpc/metadata"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// AgentService_AlgoClient is an autogenerated mock type for the AgentService_AlgoClient type
|
||||
type AgentService_AlgoClient[Req interface{}, Res interface{}] struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_AlgoClient_Expecter[Req interface{}, Res interface{}] struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) EXPECT() *AgentService_AlgoClient_Expecter[Req, Res] {
|
||||
return &AgentService_AlgoClient_Expecter[Req, Res]{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseAndRecv provides a mock function with no fields
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) CloseAndRecv() (*agent.AlgoResponse, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseAndRecv")
|
||||
}
|
||||
|
||||
var r0 *agent.AlgoResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (*agent.AlgoResponse, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() *agent.AlgoResponse); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.AlgoResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_CloseAndRecv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseAndRecv'
|
||||
type AgentService_AlgoClient_CloseAndRecv_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseAndRecv is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) CloseAndRecv() *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_CloseAndRecv_Call[Req, Res]{Call: _e.mock.On("CloseAndRecv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res]) Run(run func()) *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res]) Return(_a0 *agent.AlgoResponse, _a1 error) *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res]) RunAndReturn(run func() (*agent.AlgoResponse, error)) *AgentService_AlgoClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function with no fields
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) CloseSend() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_AlgoClient_CloseSend_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) CloseSend() *AgentService_AlgoClient_CloseSend_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_CloseSend_Call[Req, Res]{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call[Req, Res]) Run(run func()) *AgentService_AlgoClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call[Req, Res]) Return(_a0 error) *AgentService_AlgoClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call[Req, Res]) RunAndReturn(run func() error) *AgentService_AlgoClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function with no fields
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) Context() context.Context {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if rf, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_AlgoClient_Context_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) Context() *AgentService_AlgoClient_Context_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_Context_Call[Req, Res]{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call[Req, Res]) Run(run func()) *AgentService_AlgoClient_Context_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call[Req, Res]) Return(_a0 context.Context) *AgentService_AlgoClient_Context_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call[Req, Res]) RunAndReturn(run func() context.Context) *AgentService_AlgoClient_Context_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function with no fields
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) Header() (metadata.MD, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_AlgoClient_Header_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) Header() *AgentService_AlgoClient_Header_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_Header_Call[Req, Res]{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call[Req, Res]) Run(run func()) *AgentService_AlgoClient_Header_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call[Req, Res]) Return(_a0 metadata.MD, _a1 error) *AgentService_AlgoClient_Header_Call[Req, Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call[Req, Res]) RunAndReturn(run func() (metadata.MD, error)) *AgentService_AlgoClient_Header_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) RecvMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_AlgoClient_RecvMsg_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) RecvMsg(m interface{}) *AgentService_AlgoClient_RecvMsg_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_RecvMsg_Call[Req, Res]{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call[Req, Res]) Run(run func(m interface{})) *AgentService_AlgoClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call[Req, Res]) Return(_a0 error) *AgentService_AlgoClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call[Req, Res]) RunAndReturn(run func(interface{}) error) *AgentService_AlgoClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Send provides a mock function with given fields: _a0
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) Send(_a0 *agent.AlgoRequest) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Send")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*agent.AlgoRequest) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send'
|
||||
type AgentService_AlgoClient_Send_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Send is a helper method to define mock.On call
|
||||
// - _a0 *agent.AlgoRequest
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) Send(_a0 interface{}) *AgentService_AlgoClient_Send_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_Send_Call[Req, Res]{Call: _e.mock.On("Send", _a0)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call[Req, Res]) Run(run func(_a0 *agent.AlgoRequest)) *AgentService_AlgoClient_Send_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*agent.AlgoRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call[Req, Res]) Return(_a0 error) *AgentService_AlgoClient_Send_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call[Req, Res]) RunAndReturn(run func(*agent.AlgoRequest) error) *AgentService_AlgoClient_Send_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) SendMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_AlgoClient_SendMsg_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) SendMsg(m interface{}) *AgentService_AlgoClient_SendMsg_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_SendMsg_Call[Req, Res]{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call[Req, Res]) Run(run func(m interface{})) *AgentService_AlgoClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call[Req, Res]) Return(_a0 error) *AgentService_AlgoClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call[Req, Res]) RunAndReturn(run func(interface{}) error) *AgentService_AlgoClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function with no fields
|
||||
func (_m *AgentService_AlgoClient[Req, Res]) Trailer() metadata.MD {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_AlgoClient_Trailer_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter[Req, Res]) Trailer() *AgentService_AlgoClient_Trailer_Call[Req, Res] {
|
||||
return &AgentService_AlgoClient_Trailer_Call[Req, Res]{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call[Req, Res]) Run(run func()) *AgentService_AlgoClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call[Req, Res]) Return(_a0 metadata.MD) *AgentService_AlgoClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call[Req, Res]) RunAndReturn(run func() metadata.MD) *AgentService_AlgoClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAgentService_AlgoClient creates a new instance of AgentService_AlgoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_AlgoClient[Req interface{}, Res interface{}](t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_AlgoClient[Req, Res] {
|
||||
mock := &AgentService_AlgoClient[Req, Res]{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,434 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
metadata "google.golang.org/grpc/metadata"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// AgentService_DataClient is an autogenerated mock type for the AgentService_DataClient type
|
||||
type AgentService_DataClient[Req interface{}, Res interface{}] struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_DataClient_Expecter[Req interface{}, Res interface{}] struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_DataClient[Req, Res]) EXPECT() *AgentService_DataClient_Expecter[Req, Res] {
|
||||
return &AgentService_DataClient_Expecter[Req, Res]{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseAndRecv provides a mock function with no fields
|
||||
func (_m *AgentService_DataClient[Req, Res]) CloseAndRecv() (*agent.DataResponse, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseAndRecv")
|
||||
}
|
||||
|
||||
var r0 *agent.DataResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (*agent.DataResponse, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() *agent.DataResponse); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.DataResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_DataClient_CloseAndRecv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseAndRecv'
|
||||
type AgentService_DataClient_CloseAndRecv_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseAndRecv is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) CloseAndRecv() *AgentService_DataClient_CloseAndRecv_Call[Req, Res] {
|
||||
return &AgentService_DataClient_CloseAndRecv_Call[Req, Res]{Call: _e.mock.On("CloseAndRecv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call[Req, Res]) Run(run func()) *AgentService_DataClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call[Req, Res]) Return(_a0 *agent.DataResponse, _a1 error) *AgentService_DataClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call[Req, Res]) RunAndReturn(run func() (*agent.DataResponse, error)) *AgentService_DataClient_CloseAndRecv_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function with no fields
|
||||
func (_m *AgentService_DataClient[Req, Res]) CloseSend() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_DataClient_CloseSend_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) CloseSend() *AgentService_DataClient_CloseSend_Call[Req, Res] {
|
||||
return &AgentService_DataClient_CloseSend_Call[Req, Res]{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call[Req, Res]) Run(run func()) *AgentService_DataClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call[Req, Res]) Return(_a0 error) *AgentService_DataClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call[Req, Res]) RunAndReturn(run func() error) *AgentService_DataClient_CloseSend_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function with no fields
|
||||
func (_m *AgentService_DataClient[Req, Res]) Context() context.Context {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if rf, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_DataClient_Context_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) Context() *AgentService_DataClient_Context_Call[Req, Res] {
|
||||
return &AgentService_DataClient_Context_Call[Req, Res]{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call[Req, Res]) Run(run func()) *AgentService_DataClient_Context_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call[Req, Res]) Return(_a0 context.Context) *AgentService_DataClient_Context_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call[Req, Res]) RunAndReturn(run func() context.Context) *AgentService_DataClient_Context_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function with no fields
|
||||
func (_m *AgentService_DataClient[Req, Res]) Header() (metadata.MD, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_DataClient_Header_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) Header() *AgentService_DataClient_Header_Call[Req, Res] {
|
||||
return &AgentService_DataClient_Header_Call[Req, Res]{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call[Req, Res]) Run(run func()) *AgentService_DataClient_Header_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call[Req, Res]) Return(_a0 metadata.MD, _a1 error) *AgentService_DataClient_Header_Call[Req, Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call[Req, Res]) RunAndReturn(run func() (metadata.MD, error)) *AgentService_DataClient_Header_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_DataClient[Req, Res]) RecvMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_DataClient_RecvMsg_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) RecvMsg(m interface{}) *AgentService_DataClient_RecvMsg_Call[Req, Res] {
|
||||
return &AgentService_DataClient_RecvMsg_Call[Req, Res]{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call[Req, Res]) Run(run func(m interface{})) *AgentService_DataClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call[Req, Res]) Return(_a0 error) *AgentService_DataClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call[Req, Res]) RunAndReturn(run func(interface{}) error) *AgentService_DataClient_RecvMsg_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Send provides a mock function with given fields: _a0
|
||||
func (_m *AgentService_DataClient[Req, Res]) Send(_a0 *agent.DataRequest) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Send")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*agent.DataRequest) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send'
|
||||
type AgentService_DataClient_Send_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Send is a helper method to define mock.On call
|
||||
// - _a0 *agent.DataRequest
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) Send(_a0 interface{}) *AgentService_DataClient_Send_Call[Req, Res] {
|
||||
return &AgentService_DataClient_Send_Call[Req, Res]{Call: _e.mock.On("Send", _a0)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call[Req, Res]) Run(run func(_a0 *agent.DataRequest)) *AgentService_DataClient_Send_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*agent.DataRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call[Req, Res]) Return(_a0 error) *AgentService_DataClient_Send_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call[Req, Res]) RunAndReturn(run func(*agent.DataRequest) error) *AgentService_DataClient_Send_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_DataClient[Req, Res]) SendMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_DataClient_SendMsg_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) SendMsg(m interface{}) *AgentService_DataClient_SendMsg_Call[Req, Res] {
|
||||
return &AgentService_DataClient_SendMsg_Call[Req, Res]{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call[Req, Res]) Run(run func(m interface{})) *AgentService_DataClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call[Req, Res]) Return(_a0 error) *AgentService_DataClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call[Req, Res]) RunAndReturn(run func(interface{}) error) *AgentService_DataClient_SendMsg_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function with no fields
|
||||
func (_m *AgentService_DataClient[Req, Res]) Trailer() metadata.MD {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_DataClient_Trailer_Call[Req interface{}, Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter[Req, Res]) Trailer() *AgentService_DataClient_Trailer_Call[Req, Res] {
|
||||
return &AgentService_DataClient_Trailer_Call[Req, Res]{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call[Req, Res]) Run(run func()) *AgentService_DataClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call[Req, Res]) Return(_a0 metadata.MD) *AgentService_DataClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call[Req, Res]) RunAndReturn(run func() metadata.MD) *AgentService_DataClient_Trailer_Call[Req, Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAgentService_DataClient creates a new instance of AgentService_DataClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_DataClient[Req interface{}, Res interface{}](t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_DataClient[Req, Res] {
|
||||
mock := &AgentService_DataClient[Req, Res]{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,442 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// NewAgentService_AlgoClient creates a new instance of AgentService_AlgoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_AlgoClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_AlgoClient {
|
||||
mock := &AgentService_AlgoClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient is an autogenerated mock type for the AgentService_AlgoClient type
|
||||
type AgentService_AlgoClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_AlgoClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_AlgoClient) EXPECT() *AgentService_AlgoClient_Expecter {
|
||||
return &AgentService_AlgoClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseAndRecv provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) CloseAndRecv() (*agent.AlgoResponse, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseAndRecv")
|
||||
}
|
||||
|
||||
var r0 *agent.AlgoResponse
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (*agent.AlgoResponse, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() *agent.AlgoResponse); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.AlgoResponse)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_CloseAndRecv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseAndRecv'
|
||||
type AgentService_AlgoClient_CloseAndRecv_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseAndRecv is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter) CloseAndRecv() *AgentService_AlgoClient_CloseAndRecv_Call {
|
||||
return &AgentService_AlgoClient_CloseAndRecv_Call{Call: _e.mock.On("CloseAndRecv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call) Run(run func()) *AgentService_AlgoClient_CloseAndRecv_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call) Return(algoResponse *agent.AlgoResponse, err error) *AgentService_AlgoClient_CloseAndRecv_Call {
|
||||
_c.Call.Return(algoResponse, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseAndRecv_Call) RunAndReturn(run func() (*agent.AlgoResponse, error)) *AgentService_AlgoClient_CloseAndRecv_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) CloseSend() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_AlgoClient_CloseSend_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter) CloseSend() *AgentService_AlgoClient_CloseSend_Call {
|
||||
return &AgentService_AlgoClient_CloseSend_Call{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call) Run(run func()) *AgentService_AlgoClient_CloseSend_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call) Return(err error) *AgentService_AlgoClient_CloseSend_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_CloseSend_Call) RunAndReturn(run func() error) *AgentService_AlgoClient_CloseSend_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) Context() context.Context {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if returnFunc, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_AlgoClient_Context_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter) Context() *AgentService_AlgoClient_Context_Call {
|
||||
return &AgentService_AlgoClient_Context_Call{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call) Run(run func()) *AgentService_AlgoClient_Context_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call) Return(context1 context.Context) *AgentService_AlgoClient_Context_Call {
|
||||
_c.Call.Return(context1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Context_Call) RunAndReturn(run func() context.Context) *AgentService_AlgoClient_Context_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) Header() (metadata.MD, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_AlgoClient_Header_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter) Header() *AgentService_AlgoClient_Header_Call {
|
||||
return &AgentService_AlgoClient_Header_Call{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call) Run(run func()) *AgentService_AlgoClient_Header_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call) Return(mD metadata.MD, err error) *AgentService_AlgoClient_Header_Call {
|
||||
_c.Call.Return(mD, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Header_Call) RunAndReturn(run func() (metadata.MD, error)) *AgentService_AlgoClient_Header_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) RecvMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_AlgoClient_RecvMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_AlgoClient_Expecter) RecvMsg(m interface{}) *AgentService_AlgoClient_RecvMsg_Call {
|
||||
return &AgentService_AlgoClient_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call) Run(run func(m any)) *AgentService_AlgoClient_RecvMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call) Return(err error) *AgentService_AlgoClient_RecvMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_RecvMsg_Call) RunAndReturn(run func(m any) error) *AgentService_AlgoClient_RecvMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Send provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) Send(algoRequest *agent.AlgoRequest) error {
|
||||
ret := _mock.Called(algoRequest)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Send")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(*agent.AlgoRequest) error); ok {
|
||||
r0 = returnFunc(algoRequest)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send'
|
||||
type AgentService_AlgoClient_Send_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Send is a helper method to define mock.On call
|
||||
// - algoRequest *agent.AlgoRequest
|
||||
func (_e *AgentService_AlgoClient_Expecter) Send(algoRequest interface{}) *AgentService_AlgoClient_Send_Call {
|
||||
return &AgentService_AlgoClient_Send_Call{Call: _e.mock.On("Send", algoRequest)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call) Run(run func(algoRequest *agent.AlgoRequest)) *AgentService_AlgoClient_Send_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 *agent.AlgoRequest
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*agent.AlgoRequest)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call) Return(err error) *AgentService_AlgoClient_Send_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Send_Call) RunAndReturn(run func(algoRequest *agent.AlgoRequest) error) *AgentService_AlgoClient_Send_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) SendMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_AlgoClient_SendMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_AlgoClient_Expecter) SendMsg(m interface{}) *AgentService_AlgoClient_SendMsg_Call {
|
||||
return &AgentService_AlgoClient_SendMsg_Call{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call) Run(run func(m any)) *AgentService_AlgoClient_SendMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call) Return(err error) *AgentService_AlgoClient_SendMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_SendMsg_Call) RunAndReturn(run func(m any) error) *AgentService_AlgoClient_SendMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function for the type AgentService_AlgoClient
|
||||
func (_mock *AgentService_AlgoClient) Trailer() metadata.MD {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_AlgoClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_AlgoClient_Trailer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_AlgoClient_Expecter) Trailer() *AgentService_AlgoClient_Trailer_Call {
|
||||
return &AgentService_AlgoClient_Trailer_Call{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call) Run(run func()) *AgentService_AlgoClient_Trailer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call) Return(mD metadata.MD) *AgentService_AlgoClient_Trailer_Call {
|
||||
_c.Call.Return(mD)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_AlgoClient_Trailer_Call) RunAndReturn(run func() metadata.MD) *AgentService_AlgoClient_Trailer_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,442 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// NewAgentService_DataClient creates a new instance of AgentService_DataClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_DataClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_DataClient {
|
||||
mock := &AgentService_DataClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// AgentService_DataClient is an autogenerated mock type for the AgentService_DataClient type
|
||||
type AgentService_DataClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_DataClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_DataClient) EXPECT() *AgentService_DataClient_Expecter {
|
||||
return &AgentService_DataClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseAndRecv provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) CloseAndRecv() (*agent.DataResponse, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseAndRecv")
|
||||
}
|
||||
|
||||
var r0 *agent.DataResponse
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (*agent.DataResponse, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() *agent.DataResponse); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.DataResponse)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_DataClient_CloseAndRecv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseAndRecv'
|
||||
type AgentService_DataClient_CloseAndRecv_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseAndRecv is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter) CloseAndRecv() *AgentService_DataClient_CloseAndRecv_Call {
|
||||
return &AgentService_DataClient_CloseAndRecv_Call{Call: _e.mock.On("CloseAndRecv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call) Run(run func()) *AgentService_DataClient_CloseAndRecv_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call) Return(dataResponse *agent.DataResponse, err error) *AgentService_DataClient_CloseAndRecv_Call {
|
||||
_c.Call.Return(dataResponse, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseAndRecv_Call) RunAndReturn(run func() (*agent.DataResponse, error)) *AgentService_DataClient_CloseAndRecv_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) CloseSend() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_DataClient_CloseSend_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter) CloseSend() *AgentService_DataClient_CloseSend_Call {
|
||||
return &AgentService_DataClient_CloseSend_Call{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call) Run(run func()) *AgentService_DataClient_CloseSend_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call) Return(err error) *AgentService_DataClient_CloseSend_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_CloseSend_Call) RunAndReturn(run func() error) *AgentService_DataClient_CloseSend_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) Context() context.Context {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if returnFunc, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_DataClient_Context_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter) Context() *AgentService_DataClient_Context_Call {
|
||||
return &AgentService_DataClient_Context_Call{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call) Run(run func()) *AgentService_DataClient_Context_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call) Return(context1 context.Context) *AgentService_DataClient_Context_Call {
|
||||
_c.Call.Return(context1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Context_Call) RunAndReturn(run func() context.Context) *AgentService_DataClient_Context_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) Header() (metadata.MD, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_DataClient_Header_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter) Header() *AgentService_DataClient_Header_Call {
|
||||
return &AgentService_DataClient_Header_Call{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call) Run(run func()) *AgentService_DataClient_Header_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call) Return(mD metadata.MD, err error) *AgentService_DataClient_Header_Call {
|
||||
_c.Call.Return(mD, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Header_Call) RunAndReturn(run func() (metadata.MD, error)) *AgentService_DataClient_Header_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) RecvMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_DataClient_RecvMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_DataClient_Expecter) RecvMsg(m interface{}) *AgentService_DataClient_RecvMsg_Call {
|
||||
return &AgentService_DataClient_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call) Run(run func(m any)) *AgentService_DataClient_RecvMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call) Return(err error) *AgentService_DataClient_RecvMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_RecvMsg_Call) RunAndReturn(run func(m any) error) *AgentService_DataClient_RecvMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Send provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) Send(dataRequest *agent.DataRequest) error {
|
||||
ret := _mock.Called(dataRequest)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Send")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(*agent.DataRequest) error); ok {
|
||||
r0 = returnFunc(dataRequest)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send'
|
||||
type AgentService_DataClient_Send_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Send is a helper method to define mock.On call
|
||||
// - dataRequest *agent.DataRequest
|
||||
func (_e *AgentService_DataClient_Expecter) Send(dataRequest interface{}) *AgentService_DataClient_Send_Call {
|
||||
return &AgentService_DataClient_Send_Call{Call: _e.mock.On("Send", dataRequest)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call) Run(run func(dataRequest *agent.DataRequest)) *AgentService_DataClient_Send_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 *agent.DataRequest
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*agent.DataRequest)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call) Return(err error) *AgentService_DataClient_Send_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Send_Call) RunAndReturn(run func(dataRequest *agent.DataRequest) error) *AgentService_DataClient_Send_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) SendMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_DataClient_SendMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_DataClient_Expecter) SendMsg(m interface{}) *AgentService_DataClient_SendMsg_Call {
|
||||
return &AgentService_DataClient_SendMsg_Call{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call) Run(run func(m any)) *AgentService_DataClient_SendMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call) Return(err error) *AgentService_DataClient_SendMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_SendMsg_Call) RunAndReturn(run func(m any) error) *AgentService_DataClient_SendMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function for the type AgentService_DataClient
|
||||
func (_mock *AgentService_DataClient) Trailer() metadata.MD {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_DataClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_DataClient_Trailer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_DataClient_Expecter) Trailer() *AgentService_DataClient_Trailer_Call {
|
||||
return &AgentService_DataClient_Trailer_Call{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call) Run(run func()) *AgentService_DataClient_Trailer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call) Return(mD metadata.MD) *AgentService_DataClient_Trailer_Call {
|
||||
_c.Call.Return(mD)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_DataClient_Trailer_Call) RunAndReturn(run func() metadata.MD) *AgentService_DataClient_Trailer_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// NewAgentService_IMAMeasurementsClient creates a new instance of AgentService_IMAMeasurementsClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_IMAMeasurementsClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_IMAMeasurementsClient {
|
||||
mock := &AgentService_IMAMeasurementsClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient is an autogenerated mock type for the AgentService_IMAMeasurementsClient type
|
||||
type AgentService_IMAMeasurementsClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_IMAMeasurementsClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_IMAMeasurementsClient) EXPECT() *AgentService_IMAMeasurementsClient_Expecter {
|
||||
return &AgentService_IMAMeasurementsClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) CloseSend() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_IMAMeasurementsClient_CloseSend_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) CloseSend() *AgentService_IMAMeasurementsClient_CloseSend_Call {
|
||||
return &AgentService_IMAMeasurementsClient_CloseSend_Call{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call) Run(run func()) *AgentService_IMAMeasurementsClient_CloseSend_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call) Return(err error) *AgentService_IMAMeasurementsClient_CloseSend_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call) RunAndReturn(run func() error) *AgentService_IMAMeasurementsClient_CloseSend_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) Context() context.Context {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if returnFunc, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_IMAMeasurementsClient_Context_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) Context() *AgentService_IMAMeasurementsClient_Context_Call {
|
||||
return &AgentService_IMAMeasurementsClient_Context_Call{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call) Run(run func()) *AgentService_IMAMeasurementsClient_Context_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call) Return(context1 context.Context) *AgentService_IMAMeasurementsClient_Context_Call {
|
||||
_c.Call.Return(context1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call) RunAndReturn(run func() context.Context) *AgentService_IMAMeasurementsClient_Context_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) Header() (metadata.MD, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_IMAMeasurementsClient_Header_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) Header() *AgentService_IMAMeasurementsClient_Header_Call {
|
||||
return &AgentService_IMAMeasurementsClient_Header_Call{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call) Run(run func()) *AgentService_IMAMeasurementsClient_Header_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call) Return(mD metadata.MD, err error) *AgentService_IMAMeasurementsClient_Header_Call {
|
||||
_c.Call.Return(mD, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call) RunAndReturn(run func() (metadata.MD, error)) *AgentService_IMAMeasurementsClient_Header_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Recv provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) Recv() (*agent.IMAMeasurementsResponse, error) {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Recv")
|
||||
}
|
||||
|
||||
var r0 *agent.IMAMeasurementsResponse
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func() (*agent.IMAMeasurementsResponse, error)); ok {
|
||||
return returnFunc()
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func() *agent.IMAMeasurementsResponse); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.IMAMeasurementsResponse)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = returnFunc()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv'
|
||||
type AgentService_IMAMeasurementsClient_Recv_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Recv is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) Recv() *AgentService_IMAMeasurementsClient_Recv_Call {
|
||||
return &AgentService_IMAMeasurementsClient_Recv_Call{Call: _e.mock.On("Recv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call) Run(run func()) *AgentService_IMAMeasurementsClient_Recv_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call) Return(iMAMeasurementsResponse *agent.IMAMeasurementsResponse, err error) *AgentService_IMAMeasurementsClient_Recv_Call {
|
||||
_c.Call.Return(iMAMeasurementsResponse, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call) RunAndReturn(run func() (*agent.IMAMeasurementsResponse, error)) *AgentService_IMAMeasurementsClient_Recv_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) RecvMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_IMAMeasurementsClient_RecvMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) RecvMsg(m interface{}) *AgentService_IMAMeasurementsClient_RecvMsg_Call {
|
||||
return &AgentService_IMAMeasurementsClient_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call) Run(run func(m any)) *AgentService_IMAMeasurementsClient_RecvMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call) Return(err error) *AgentService_IMAMeasurementsClient_RecvMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call) RunAndReturn(run func(m any) error) *AgentService_IMAMeasurementsClient_RecvMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) SendMsg(m any) error {
|
||||
ret := _mock.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(any) error); ok {
|
||||
r0 = returnFunc(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_IMAMeasurementsClient_SendMsg_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m any
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) SendMsg(m interface{}) *AgentService_IMAMeasurementsClient_SendMsg_Call {
|
||||
return &AgentService_IMAMeasurementsClient_SendMsg_Call{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call) Run(run func(m any)) *AgentService_IMAMeasurementsClient_SendMsg_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 any
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(any)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call) Return(err error) *AgentService_IMAMeasurementsClient_SendMsg_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call) RunAndReturn(run func(m any) error) *AgentService_IMAMeasurementsClient_SendMsg_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function for the type AgentService_IMAMeasurementsClient
|
||||
func (_mock *AgentService_IMAMeasurementsClient) Trailer() metadata.MD {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if returnFunc, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_IMAMeasurementsClient_Trailer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter) Trailer() *AgentService_IMAMeasurementsClient_Trailer_Call {
|
||||
return &AgentService_IMAMeasurementsClient_Trailer_Call{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call) Run(run func()) *AgentService_IMAMeasurementsClient_Trailer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call) Return(mD metadata.MD) *AgentService_IMAMeasurementsClient_Trailer_Call {
|
||||
_c.Call.Return(mD)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call) RunAndReturn(run func() metadata.MD) *AgentService_IMAMeasurementsClient_Trailer_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,589 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Service_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Service) EXPECT() *Service_Expecter {
|
||||
return &Service_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Algo provides a mock function for the type Service
|
||||
func (_mock *Service) Algo(ctx context.Context, algorithm agent.Algorithm) error {
|
||||
ret := _mock.Called(ctx, algorithm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Algo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, agent.Algorithm) error); ok {
|
||||
r0 = returnFunc(ctx, algorithm)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Algo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Algo'
|
||||
type Service_Algo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Algo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - algorithm agent.Algorithm
|
||||
func (_e *Service_Expecter) Algo(ctx interface{}, algorithm interface{}) *Service_Algo_Call {
|
||||
return &Service_Algo_Call{Call: _e.mock.On("Algo", ctx, algorithm)}
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) Run(run func(ctx context.Context, algorithm agent.Algorithm)) *Service_Algo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 agent.Algorithm
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(agent.Algorithm)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) Return(err error) *Service_Algo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Algo_Call) RunAndReturn(run func(ctx context.Context, algorithm agent.Algorithm) error) *Service_Algo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Attestation provides a mock function for the type Service
|
||||
func (_mock *Service) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
ret := _mock.Called(ctx, reportData, nonce, attType)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Attestation")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, attestation.PlatformType) ([]byte, error)); ok {
|
||||
return returnFunc(ctx, reportData, nonce, attType)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, attestation.PlatformType) []byte); ok {
|
||||
r0 = returnFunc(ctx, reportData, nonce, attType)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, [64]byte, [32]byte, attestation.PlatformType) error); ok {
|
||||
r1 = returnFunc(ctx, reportData, nonce, attType)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_Attestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Attestation'
|
||||
type Service_Attestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Attestation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - reportData [64]byte
|
||||
// - nonce [32]byte
|
||||
// - attType attestation.PlatformType
|
||||
func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}) *Service_Attestation_Call {
|
||||
return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType)}
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType)) *Service_Attestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 [64]byte
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].([64]byte)
|
||||
}
|
||||
var arg2 [32]byte
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].([32]byte)
|
||||
}
|
||||
var arg3 attestation.PlatformType
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(attestation.PlatformType)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) Return(bytes []byte, err error) *Service_Attestation_Call {
|
||||
_c.Call.Return(bytes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Attestation_Call) RunAndReturn(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error)) *Service_Attestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// AzureAttestationToken provides a mock function for the type Service
|
||||
func (_mock *Service) AzureAttestationToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
||||
ret := _mock.Called(ctx, nonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AzureAttestationToken")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, [32]byte) ([]byte, error)); ok {
|
||||
return returnFunc(ctx, nonce)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, [32]byte) []byte); ok {
|
||||
r0 = returnFunc(ctx, nonce)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, [32]byte) error); ok {
|
||||
r1 = returnFunc(ctx, nonce)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_AzureAttestationToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AzureAttestationToken'
|
||||
type Service_AzureAttestationToken_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AzureAttestationToken is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - nonce [32]byte
|
||||
func (_e *Service_Expecter) AzureAttestationToken(ctx interface{}, nonce interface{}) *Service_AzureAttestationToken_Call {
|
||||
return &Service_AzureAttestationToken_Call{Call: _e.mock.On("AzureAttestationToken", ctx, nonce)}
|
||||
}
|
||||
|
||||
func (_c *Service_AzureAttestationToken_Call) Run(run func(ctx context.Context, nonce [32]byte)) *Service_AzureAttestationToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 [32]byte
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].([32]byte)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_AzureAttestationToken_Call) Return(bytes []byte, err error) *Service_AzureAttestationToken_Call {
|
||||
_c.Call.Return(bytes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_AzureAttestationToken_Call) RunAndReturn(run func(ctx context.Context, nonce [32]byte) ([]byte, error)) *Service_AzureAttestationToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Data provides a mock function for the type Service
|
||||
func (_mock *Service) Data(ctx context.Context, dataset agent.Dataset) error {
|
||||
ret := _mock.Called(ctx, dataset)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Data")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, agent.Dataset) error); ok {
|
||||
r0 = returnFunc(ctx, dataset)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Data_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Data'
|
||||
type Service_Data_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Data is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - dataset agent.Dataset
|
||||
func (_e *Service_Expecter) Data(ctx interface{}, dataset interface{}) *Service_Data_Call {
|
||||
return &Service_Data_Call{Call: _e.mock.On("Data", ctx, dataset)}
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) Run(run func(ctx context.Context, dataset agent.Dataset)) *Service_Data_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 agent.Dataset
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(agent.Dataset)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) Return(err error) *Service_Data_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Data_Call) RunAndReturn(run func(ctx context.Context, dataset agent.Dataset) error) *Service_Data_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// IMAMeasurements provides a mock function for the type Service
|
||||
func (_mock *Service) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) {
|
||||
ret := _mock.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IMAMeasurements")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 []byte
|
||||
var r2 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) ([]byte, []byte, error)); ok {
|
||||
return returnFunc(ctx)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) []byte); ok {
|
||||
r0 = returnFunc(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context) []byte); ok {
|
||||
r1 = returnFunc(ctx)
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(2).(func(context.Context) error); ok {
|
||||
r2 = returnFunc(ctx)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// Service_IMAMeasurements_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IMAMeasurements'
|
||||
type Service_IMAMeasurements_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// IMAMeasurements is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) IMAMeasurements(ctx interface{}) *Service_IMAMeasurements_Call {
|
||||
return &Service_IMAMeasurements_Call{Call: _e.mock.On("IMAMeasurements", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_IMAMeasurements_Call) Run(run func(ctx context.Context)) *Service_IMAMeasurements_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_IMAMeasurements_Call) Return(bytes []byte, bytes1 []byte, err error) *Service_IMAMeasurements_Call {
|
||||
_c.Call.Return(bytes, bytes1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_IMAMeasurements_Call) RunAndReturn(run func(ctx context.Context) ([]byte, []byte, error)) *Service_IMAMeasurements_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// InitComputation provides a mock function for the type Service
|
||||
func (_mock *Service) InitComputation(ctx context.Context, cmp agent.Computation) error {
|
||||
ret := _mock.Called(ctx, cmp)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for InitComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, agent.Computation) error); ok {
|
||||
r0 = returnFunc(ctx, cmp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_InitComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitComputation'
|
||||
type Service_InitComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InitComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmp agent.Computation
|
||||
func (_e *Service_Expecter) InitComputation(ctx interface{}, cmp interface{}) *Service_InitComputation_Call {
|
||||
return &Service_InitComputation_Call{Call: _e.mock.On("InitComputation", ctx, cmp)}
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Run(run func(ctx context.Context, cmp agent.Computation)) *Service_InitComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 agent.Computation
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(agent.Computation)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Return(err error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) RunAndReturn(run func(ctx context.Context, cmp agent.Computation) error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Result provides a mock function for the type Service
|
||||
func (_mock *Service) Result(ctx context.Context) ([]byte, error) {
|
||||
ret := _mock.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Result")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) ([]byte, error)); ok {
|
||||
return returnFunc(ctx)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) []byte); ok {
|
||||
r0 = returnFunc(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context) error); ok {
|
||||
r1 = returnFunc(ctx)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_Result_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Result'
|
||||
type Service_Result_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Result is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) Result(ctx interface{}) *Service_Result_Call {
|
||||
return &Service_Result_Call{Call: _e.mock.On("Result", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) Run(run func(ctx context.Context)) *Service_Result_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) Return(bytes []byte, err error) *Service_Result_Call {
|
||||
_c.Call.Return(bytes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Result_Call) RunAndReturn(run func(ctx context.Context) ([]byte, error)) *Service_Result_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// State provides a mock function for the type Service
|
||||
func (_mock *Service) State() string {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for State")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if returnFunc, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_State_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'State'
|
||||
type Service_State_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// State is a helper method to define mock.On call
|
||||
func (_e *Service_Expecter) State() *Service_State_Call {
|
||||
return &Service_State_Call{Call: _e.mock.On("State")}
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) Run(run func()) *Service_State_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) Return(s string) *Service_State_Call {
|
||||
_c.Call.Return(s)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_State_Call) RunAndReturn(run func() string) *Service_State_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// StopComputation provides a mock function for the type Service
|
||||
func (_mock *Service) StopComputation(ctx context.Context) error {
|
||||
ret := _mock.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for StopComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = returnFunc(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_StopComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopComputation'
|
||||
type Service_StopComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// StopComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) StopComputation(ctx interface{}) *Service_StopComputation_Call {
|
||||
return &Service_StopComputation_Call{Call: _e.mock.On("StopComputation", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Run(run func(ctx context.Context)) *Service_StopComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Return(err error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) RunAndReturn(run func(ctx context.Context) error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/resource"
|
||||
)
|
||||
|
||||
type MockDownloader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockDownloader) Download(ctx context.Context, url string, destPath string) error {
|
||||
args := m.Called(ctx, url, destPath)
|
||||
if args.Error(0) == nil {
|
||||
// Simulate writing to destPath if it's a success
|
||||
content := "mock content"
|
||||
if len(args) > 1 {
|
||||
if c, ok := args.Get(1).(string); ok {
|
||||
content = c
|
||||
}
|
||||
}
|
||||
_ = os.MkdirAll(filepath.Dir(destPath), 0o755)
|
||||
_ = os.WriteFile(destPath, []byte(content), 0o644)
|
||||
}
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockDownloader) Type() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
registry := resource.NewRegistry()
|
||||
mockDownloader := new(MockDownloader)
|
||||
mockDownloader.On("Type").Return(resource.SourceTypeHTTP)
|
||||
registry.Register(mockDownloader)
|
||||
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
resourceRegistry: registry,
|
||||
computation: Computation{
|
||||
Algorithm: &Algorithm{
|
||||
KBS: &KBSConfig{
|
||||
Enabled: true,
|
||||
URL: "http://mock-kbs",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Successful download without encryption", func(t *testing.T) {
|
||||
source := &ResourceSource{
|
||||
URL: "http://example.com/resource",
|
||||
}
|
||||
destPath := filepath.Join(os.TempDir(), "cocos-resources", "algo", "resource")
|
||||
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, "some data").Once()
|
||||
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "", "algo")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("some data"), res.Data)
|
||||
mockDownloader.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Successful download with encryption", func(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, _ = io.ReadFull(rand.Reader, key)
|
||||
|
||||
plaintext := []byte("secret data")
|
||||
block, _ := aes.NewCipher(key)
|
||||
gcm, _ := cipher.NewGCM(block)
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, _ = io.ReadFull(rand.Reader, nonce)
|
||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
|
||||
// Mock KBS
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(key)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
source := &ResourceSource{
|
||||
URL: "http://example.com/encrypted",
|
||||
Encrypted: true,
|
||||
KBSResourcePath: "keys/1",
|
||||
}
|
||||
destPath := filepath.Join(os.TempDir(), "cocos-resources", "data", "encrypted")
|
||||
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, string(ciphertext)).Once()
|
||||
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, svc.computation.Algorithm.KBS.URL, "data")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, res.Data)
|
||||
mockDownloader.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Registry not initialized", func(t *testing.T) {
|
||||
badSvc := &agentService{logger: slog.Default()}
|
||||
_, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, "http", "", "algo")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "resource registry not initialized")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetKeyFromKBS(t *testing.T) {
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
computation: Computation{
|
||||
Algorithm: &Algorithm{
|
||||
KBS: &KBSConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("KBS disabled", func(t *testing.T) {
|
||||
svc.computation.Algorithm.KBS.Enabled = false
|
||||
_, err := svc.getKeyFromKBS(ctx, "", "path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Successful fetch", func(t *testing.T) {
|
||||
svc.computation.Algorithm.KBS.Enabled = true
|
||||
key := []byte("this is a 32-byte key!!!!!!!!!!!")
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Contains(t, r.URL.Path, "resource/path")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(key)
|
||||
}))
|
||||
defer ts.Close()
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
fetched, err := svc.getKeyFromKBS(ctx, svc.computation.Algorithm.KBS.URL, "path")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, fetched)
|
||||
})
|
||||
|
||||
t.Run("KBS error", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer ts.Close()
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
_, err := svc.getKeyFromKBS(ctx, svc.computation.Algorithm.KBS.URL, "path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInferSourceTypeDetailed(t *testing.T) {
|
||||
tests := []struct {
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{"s3://bucket/key", resource.SourceTypeS3},
|
||||
{"gs://bucket/key", resource.SourceTypeGCS},
|
||||
{"https://example.com/file", resource.SourceTypeHTTPS},
|
||||
{"http://example.com/file", resource.SourceTypeHTTP},
|
||||
{"docker://ubuntu", resource.SourceTypeOCIImage},
|
||||
{"oci:/path/to/dir", resource.SourceTypeOCIImage},
|
||||
{"ubuntu:latest", resource.SourceTypeOCIImage},
|
||||
{"myregistry.io/myimage:tag", resource.SourceTypeOCIImage},
|
||||
{"invalid-url-no-slash", ""},
|
||||
{"", ""},
|
||||
{"ftp://server/file", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
assert.Equal(t, tt.expected, inferSourceType(tt.url), "URL: %s", tt.url)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
)
|
||||
|
||||
type adapter struct {
|
||||
client logclient.Client
|
||||
svc string
|
||||
}
|
||||
|
||||
func NewAdapter(client logclient.Client, svc string) events.Service {
|
||||
return &adapter{
|
||||
client: client,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *adapter) SendEvent(cmpID, event, status string, details json.RawMessage) {
|
||||
err := a.client.SendEvent(context.Background(), &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: details,
|
||||
Originator: a.svc,
|
||||
Status: status,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to send event to log-forwarder", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
)
|
||||
|
||||
const testServiceName = "test-service"
|
||||
|
||||
// mockLogClient is a mock implementation of the log client.
|
||||
type mockLogClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockLogClient) SendLog(ctx context.Context, entry *logpb.LogEntry) error {
|
||||
args := m.Called(ctx, entry)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLogClient) SendEvent(ctx context.Context, entry *logpb.EventEntry) error {
|
||||
args := m.Called(ctx, entry)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLogClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// TestNewAdapter tests creating a new adapter.
|
||||
func TestNewAdapter(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
assert.NotNil(t, adapter)
|
||||
}
|
||||
|
||||
// TestSendEvent tests sending an event successfully.
|
||||
func TestSendEvent(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "test-computation-id"
|
||||
event := "computation.started"
|
||||
status := "success"
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: details,
|
||||
Originator: svc,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent(cmpID, event, status, details)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
mockClient.AssertCalled(t, "SendEvent", mock.Anything, expectedEntry)
|
||||
}
|
||||
|
||||
// TestSendEventWithError tests sending an event when client returns an error.
|
||||
func TestSendEventWithError(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "test-computation-id"
|
||||
event := "computation.failed"
|
||||
status := "error"
|
||||
details := json.RawMessage(`{"error": "something went wrong"}`)
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, mock.Anything).Return(assert.AnError)
|
||||
|
||||
// This should not panic even when error occurs
|
||||
adapter.SendEvent(cmpID, event, status, details)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
mockClient.AssertCalled(t, "SendEvent", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
// TestSendEventWithNilDetails tests sending an event with nil details.
|
||||
func TestSendEventWithNilDetails(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := "runner-service"
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "comp-123"
|
||||
event := "test.event"
|
||||
status := "pending"
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: nil,
|
||||
Originator: svc,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent(cmpID, event, status, nil)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendEventWithEmptyStrings tests sending an event with empty strings.
|
||||
func TestSendEventWithEmptyStrings(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: "",
|
||||
ComputationId: "",
|
||||
Details: nil,
|
||||
Originator: svc,
|
||||
Status: "",
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent("", "", "", nil)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.1
|
||||
// source: agent/runner/runner.proto
|
||||
|
||||
package runner
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type RunRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
AlgoType string `protobuf:"bytes,2,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"` // "binary", "python", "wasm", "docker"
|
||||
Algorithm []byte `protobuf:"bytes,3,opt,name=algorithm,proto3" json:"algorithm,omitempty"` // The algorithm binary/script content
|
||||
Requirements []byte `protobuf:"bytes,4,opt,name=requirements,proto3" json:"requirements,omitempty"` // Python requirements.txt content
|
||||
Args []string `protobuf:"bytes,5,rep,name=args,proto3" json:"args,omitempty"`
|
||||
Datasets []*Dataset `protobuf:"bytes,6,rep,name=datasets,proto3" json:"datasets,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RunRequest) Reset() {
|
||||
*x = RunRequest{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RunRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RunRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RunRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RunRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RunRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetAlgoType() string {
|
||||
if x != nil {
|
||||
return x.AlgoType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetAlgorithm() []byte {
|
||||
if x != nil {
|
||||
return x.Algorithm
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetRequirements() []byte {
|
||||
if x != nil {
|
||||
return x.Requirements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetArgs() []string {
|
||||
if x != nil {
|
||||
return x.Args
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetDatasets() []*Dataset {
|
||||
if x != nil {
|
||||
return x.Datasets
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Dataset struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Filename string `protobuf:"bytes,1,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
Hash []byte `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Dataset) Reset() {
|
||||
*x = Dataset{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Dataset) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Dataset) ProtoMessage() {}
|
||||
|
||||
func (x *Dataset) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Dataset.ProtoReflect.Descriptor instead.
|
||||
func (*Dataset) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *Dataset) GetFilename() string {
|
||||
if x != nil {
|
||||
return x.Filename
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Dataset) GetHash() []byte {
|
||||
if x != nil {
|
||||
return x.Hash
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RunResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RunResponse) Reset() {
|
||||
*x = RunResponse{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RunResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RunResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RunResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RunResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RunResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RunResponse) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunResponse) GetError() string {
|
||||
if x != nil {
|
||||
return x.Error
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type StopRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopRequest) Reset() {
|
||||
*x = StopRequest{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *StopRequest) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_agent_runner_runner_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_agent_runner_runner_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x19agent/runner/runner.proto\x12\x06runner\x1a\x1bgoogle/protobuf/empty.proto\"\xd3\x01\n" +
|
||||
"\n" +
|
||||
"RunRequest\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x1b\n" +
|
||||
"\talgo_type\x18\x02 \x01(\tR\balgoType\x12\x1c\n" +
|
||||
"\talgorithm\x18\x03 \x01(\fR\talgorithm\x12\"\n" +
|
||||
"\frequirements\x18\x04 \x01(\fR\frequirements\x12\x12\n" +
|
||||
"\x04args\x18\x05 \x03(\tR\x04args\x12+\n" +
|
||||
"\bdatasets\x18\x06 \x03(\v2\x0f.runner.DatasetR\bdatasets\"9\n" +
|
||||
"\aDataset\x12\x1a\n" +
|
||||
"\bfilename\x18\x01 \x01(\tR\bfilename\x12\x12\n" +
|
||||
"\x04hash\x18\x02 \x01(\fR\x04hash\"J\n" +
|
||||
"\vRunResponse\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05error\x18\x02 \x01(\tR\x05error\"4\n" +
|
||||
"\vStopRequest\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId2x\n" +
|
||||
"\x11ComputationRunner\x12.\n" +
|
||||
"\x03Run\x12\x12.runner.RunRequest\x1a\x13.runner.RunResponse\x123\n" +
|
||||
"\x04Stop\x12\x13.runner.StopRequest\x1a\x16.google.protobuf.EmptyB\n" +
|
||||
"Z\b./runnerb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_runner_runner_proto_rawDescOnce sync.Once
|
||||
file_agent_runner_runner_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_agent_runner_runner_proto_rawDescGZIP() []byte {
|
||||
file_agent_runner_runner_proto_rawDescOnce.Do(func() {
|
||||
file_agent_runner_runner_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)))
|
||||
})
|
||||
return file_agent_runner_runner_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_runner_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_agent_runner_runner_proto_goTypes = []any{
|
||||
(*RunRequest)(nil), // 0: runner.RunRequest
|
||||
(*Dataset)(nil), // 1: runner.Dataset
|
||||
(*RunResponse)(nil), // 2: runner.RunResponse
|
||||
(*StopRequest)(nil), // 3: runner.StopRequest
|
||||
(*emptypb.Empty)(nil), // 4: google.protobuf.Empty
|
||||
}
|
||||
var file_agent_runner_runner_proto_depIdxs = []int32{
|
||||
1, // 0: runner.RunRequest.datasets:type_name -> runner.Dataset
|
||||
0, // 1: runner.ComputationRunner.Run:input_type -> runner.RunRequest
|
||||
3, // 2: runner.ComputationRunner.Stop:input_type -> runner.StopRequest
|
||||
2, // 3: runner.ComputationRunner.Run:output_type -> runner.RunResponse
|
||||
4, // 4: runner.ComputationRunner.Stop:output_type -> google.protobuf.Empty
|
||||
3, // [3:5] is the sub-list for method output_type
|
||||
1, // [1:3] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_runner_runner_proto_init() }
|
||||
func file_agent_runner_runner_proto_init() {
|
||||
if File_agent_runner_runner_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 4,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_agent_runner_runner_proto_goTypes,
|
||||
DependencyIndexes: file_agent_runner_runner_proto_depIdxs,
|
||||
MessageInfos: file_agent_runner_runner_proto_msgTypes,
|
||||
}.Build()
|
||||
File_agent_runner_runner_proto = out.File
|
||||
file_agent_runner_runner_proto_goTypes = nil
|
||||
file_agent_runner_runner_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package runner;
|
||||
|
||||
option go_package = "./runner";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
service ComputationRunner {
|
||||
rpc Run(RunRequest) returns (RunResponse);
|
||||
rpc Stop(StopRequest) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
||||
message RunRequest {
|
||||
string computation_id = 1;
|
||||
string algo_type = 2; // "binary", "python", "wasm", "docker"
|
||||
bytes algorithm = 3; // The algorithm binary/script content
|
||||
bytes requirements = 4; // Python requirements.txt content
|
||||
repeated string args = 5;
|
||||
repeated Dataset datasets = 6;
|
||||
}
|
||||
|
||||
message Dataset {
|
||||
string filename = 1;
|
||||
bytes hash = 2;
|
||||
}
|
||||
|
||||
message RunResponse {
|
||||
string computation_id = 1;
|
||||
string error = 2;
|
||||
}
|
||||
|
||||
message StopRequest {
|
||||
string computation_id = 1;
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.0
|
||||
// - protoc v6.33.1
|
||||
// source: agent/runner/runner.proto
|
||||
|
||||
package runner
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
ComputationRunner_Run_FullMethodName = "/runner.ComputationRunner/Run"
|
||||
ComputationRunner_Stop_FullMethodName = "/runner.ComputationRunner/Stop"
|
||||
)
|
||||
|
||||
// ComputationRunnerClient is the client API for ComputationRunner service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ComputationRunnerClient interface {
|
||||
Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error)
|
||||
Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type computationRunnerClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewComputationRunnerClient(cc grpc.ClientConnInterface) ComputationRunnerClient {
|
||||
return &computationRunnerClient{cc}
|
||||
}
|
||||
|
||||
func (c *computationRunnerClient) Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RunResponse)
|
||||
err := c.cc.Invoke(ctx, ComputationRunner_Run_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *computationRunnerClient) Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, ComputationRunner_Stop_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ComputationRunnerServer is the server API for ComputationRunner service.
|
||||
// All implementations must embed UnimplementedComputationRunnerServer
|
||||
// for forward compatibility.
|
||||
type ComputationRunnerServer interface {
|
||||
Run(context.Context, *RunRequest) (*RunResponse, error)
|
||||
Stop(context.Context, *StopRequest) (*emptypb.Empty, error)
|
||||
mustEmbedUnimplementedComputationRunnerServer()
|
||||
}
|
||||
|
||||
// UnimplementedComputationRunnerServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedComputationRunnerServer struct{}
|
||||
|
||||
func (UnimplementedComputationRunnerServer) Run(context.Context, *RunRequest) (*RunResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Run not implemented")
|
||||
}
|
||||
func (UnimplementedComputationRunnerServer) Stop(context.Context, *StopRequest) (*emptypb.Empty, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Stop not implemented")
|
||||
}
|
||||
func (UnimplementedComputationRunnerServer) mustEmbedUnimplementedComputationRunnerServer() {}
|
||||
func (UnimplementedComputationRunnerServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeComputationRunnerServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to ComputationRunnerServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeComputationRunnerServer interface {
|
||||
mustEmbedUnimplementedComputationRunnerServer()
|
||||
}
|
||||
|
||||
func RegisterComputationRunnerServer(s grpc.ServiceRegistrar, srv ComputationRunnerServer) {
|
||||
// If the following call panics, it indicates UnimplementedComputationRunnerServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&ComputationRunner_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _ComputationRunner_Run_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RunRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ComputationRunnerServer).Run(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ComputationRunner_Run_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ComputationRunnerServer).Run(ctx, req.(*RunRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ComputationRunner_Stop_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StopRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ComputationRunnerServer).Stop(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ComputationRunner_Stop_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ComputationRunnerServer).Stop(ctx, req.(*StopRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ComputationRunner_ServiceDesc is the grpc.ServiceDesc for ComputationRunner service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var ComputationRunner_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "runner.ComputationRunner",
|
||||
HandlerType: (*ComputationRunnerServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Run",
|
||||
Handler: _ComputationRunner_Run_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Stop",
|
||||
Handler: _ComputationRunner_Stop_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "agent/runner/runner.proto",
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/docker"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
const (
|
||||
algoFilePermission = 0o700
|
||||
)
|
||||
|
||||
var _ pb.ComputationRunnerServer = (*RunnerService)(nil)
|
||||
|
||||
type RunnerService struct {
|
||||
pb.UnimplementedComputationRunnerServer
|
||||
logger *slog.Logger
|
||||
eventSvc events.Service
|
||||
currentAlgo algorithm.Algorithm
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, eventSvc events.Service) *RunnerService {
|
||||
return &RunnerService{
|
||||
logger: logger,
|
||||
eventSvc: eventSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
|
||||
s.mu.Lock()
|
||||
if s.currentAlgo != nil {
|
||||
s.mu.Unlock()
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
Error: "computation already running",
|
||||
}, nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.currentAlgo = nil
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
currentDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting current directory: %v", err)
|
||||
}
|
||||
|
||||
// Write Algo File
|
||||
algoPath := filepath.Join(currentDir, "algo")
|
||||
f, err := os.Create(algoPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating algorithm file: %v", err)
|
||||
}
|
||||
if _, err := f.Write(req.Algorithm); err != nil {
|
||||
return nil, fmt.Errorf("error writing algorithm to file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(algoPath, algoFilePermission); err != nil {
|
||||
return nil, fmt.Errorf("error changing file permissions: %v", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(algoPath); err != nil {
|
||||
s.logger.Warn("error removing algorithm file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var algo algorithm.Algorithm
|
||||
|
||||
switch req.AlgoType {
|
||||
case string(algorithm.AlgoTypeBin):
|
||||
algo = binary.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.Args, req.ComputationId)
|
||||
case string(algorithm.AlgoTypePython):
|
||||
var requirementsFile string
|
||||
if len(req.Requirements) > 0 {
|
||||
fr, err := os.CreateTemp("", "requirements.txt")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating requirments file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(fr.Name()); err != nil {
|
||||
s.logger.Warn("error removing requirements file", "error", err)
|
||||
}
|
||||
}()
|
||||
if _, err := fr.Write(req.Requirements); err != nil {
|
||||
return nil, fmt.Errorf("error writing requirements to file: %v", err)
|
||||
}
|
||||
if err := fr.Close(); err != nil {
|
||||
return nil, fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
requirementsFile = fr.Name()
|
||||
}
|
||||
// Assuming default python runtime if not specified in request (proto doesn't have runtime field yet)
|
||||
// We can add it or assume.
|
||||
runtime := python.PyRuntime
|
||||
algo = python.NewAlgorithm(s.logger, s.eventSvc, runtime, requirementsFile, algoPath, req.Args, req.ComputationId)
|
||||
case string(algorithm.AlgoTypeWasm):
|
||||
algo = wasm.NewAlgorithm(s.logger, s.eventSvc, req.Args, algoPath, req.ComputationId)
|
||||
case string(algorithm.AlgoTypeDocker):
|
||||
algo = docker.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.ComputationId)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported algorithm type: %s", req.AlgoType)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.currentAlgo = algo
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := algo.Run(); err != nil {
|
||||
s.logger.Error("computation failed", "error", err)
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RunnerService) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.currentAlgo != nil {
|
||||
if err := s.currentAlgo.Stop(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@@ -0,0 +1,382 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
)
|
||||
|
||||
// MockEventService is a mock implementation of events.Service.
|
||||
type MockEventService struct {
|
||||
events []interface{}
|
||||
}
|
||||
|
||||
func (m *MockEventService) SendEvent(cmpID, event, status string, details json.RawMessage) {
|
||||
m.events = append(m.events, map[string]interface{}{
|
||||
"cmpID": cmpID,
|
||||
"event": event,
|
||||
"status": status,
|
||||
"details": details,
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewRunnerService tests the creation of a new runner service.
|
||||
func TestNewRunnerService(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
|
||||
rs := New(logger, eventSvc)
|
||||
require.NotNil(t, rs)
|
||||
assert.NotNil(t, rs.logger)
|
||||
assert.NotNil(t, rs.eventSvc)
|
||||
assert.Nil(t, rs.currentAlgo)
|
||||
}
|
||||
|
||||
// TestRunWithBinaryAlgorithm tests running a binary algorithm.
|
||||
func TestRunWithBinaryAlgorithm(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
require.NoError(t, os.Chdir(tmpDir))
|
||||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-1",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\necho 'test'"),
|
||||
Args: []string{"arg1", "arg2"},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-1", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithPythonAlgorithm tests running a Python algorithm.
|
||||
func TestRunWithPythonAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-python",
|
||||
AlgoType: "python",
|
||||
Algorithm: []byte("print('hello')"),
|
||||
Args: []string{},
|
||||
Requirements: []byte("numpy==2.2.0"),
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-python", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
|
||||
func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-python-noreq",
|
||||
AlgoType: "python",
|
||||
Algorithm: []byte("print('hello')"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-python-noreq", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithWasmAlgorithm tests running a WASM algorithm.
|
||||
func TestRunWithWasmAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-wasm",
|
||||
AlgoType: "wasm",
|
||||
Algorithm: []byte{0x00, 0x61, 0x73, 0x6d},
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
if resp.Error != "" {
|
||||
assert.Contains(t, resp.Error, "wasmedge")
|
||||
t.Skip("wasmedge not found, skipping test")
|
||||
}
|
||||
assert.Equal(t, "test-wasm", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithDockerAlgorithm tests running a Docker algorithm.
|
||||
func TestRunWithDockerAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-docker",
|
||||
AlgoType: "docker",
|
||||
Algorithm: []byte("FROM ubuntu:latest\nRUN echo 'test'"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
if resp.Error != "" {
|
||||
assert.Contains(t, resp.Error, "Docker")
|
||||
t.Skip("Docker issue, skipping test")
|
||||
}
|
||||
assert.Equal(t, "test-docker", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
|
||||
func TestRunWithUnsupportedAlgorithmType(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-unsupported",
|
||||
AlgoType: "unsupported",
|
||||
Algorithm: []byte("test"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
}
|
||||
|
||||
// TestRunAlreadyRunning tests running computation when one is already running.
|
||||
func TestRunAlreadyRunning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
// Use a long-running bash script
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-running",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 30"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
// Start first computation (will run for 30 seconds)
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Try to run another immediately - should fail
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "computation already running", resp.Error)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestStopWhenRunning tests stopping a running computation.
|
||||
func TestStopWhenRunning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-stop",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 10"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
stopReq := &pb.StopRequest{
|
||||
ComputationId: "test-stop",
|
||||
}
|
||||
|
||||
stopResp, err := rs.Stop(context.Background(), stopReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stopResp)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunErrors tests error paths in Run.
|
||||
func TestRunErrors(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
t.Run("create algo file failure", func(t *testing.T) {
|
||||
// Create a directory named "algo" to make os.Create("algo") fail
|
||||
err := os.Mkdir("algo", 0o755)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll("algo")
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-err",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("test"),
|
||||
}
|
||||
_, err = rs.Run(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error creating algorithm file")
|
||||
})
|
||||
|
||||
t.Run("getwd failure", func(t *testing.T) {
|
||||
origDir, _ := os.Getwd()
|
||||
tmpDir := t.TempDir()
|
||||
err := os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove the current working directory to trigger Getwd failure
|
||||
err = os.RemoveAll(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-err-getwd",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("test"),
|
||||
}
|
||||
_, err = rs.Run(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error getting current directory")
|
||||
|
||||
// Restore working directory
|
||||
_ = os.Chdir(origDir)
|
||||
})
|
||||
|
||||
t.Run("requirements file creation failure", func(t *testing.T) {
|
||||
// This one is harder because it uses os.CreateTemp("", "requirements.txt")
|
||||
// We can't easily make this fail without reaching into the system's temp dir.
|
||||
// Skipping for now as it's a very unlikely edge case.
|
||||
})
|
||||
|
||||
t.Run("chmod failure", func(t *testing.T) {
|
||||
// We can't easily mock os.Chmod, but we can try to make the file unmodifiable
|
||||
// On Linux, we can set the immutable attribute, but that requires root.
|
||||
// Alternatively, we can try to use a directory with permissions that prevent chmod?
|
||||
// No, chmod usually works if you own the file.
|
||||
})
|
||||
|
||||
t.Run("write algorithm failure", func(t *testing.T) {
|
||||
// This is also hard without mocking os.File.Write or reaching internal limits.
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentRun tests that concurrent runs are properly serialized.
|
||||
func TestConcurrentRun(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-concurrent",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 15"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
// Start first run in goroutine (will run for 15 seconds)
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to actually start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Concurrent attempt should fail
|
||||
resp2, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "computation already running", resp2.Error)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRunWithMultipleArgs tests running with multiple arguments.
|
||||
func TestRunWithMultipleArgs(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-multi-args",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\necho $@"),
|
||||
Args: []string{"arg1", "arg2", "arg3", "arg4"},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Empty(t, resp.Error)
|
||||
assert.Equal(t, "test-multi-args", resp.ComputationId)
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove("algo")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopFailure(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
// Mock an algorithm that fails on Stop
|
||||
rs.currentAlgo = &MockAlgorithmStopFail{}
|
||||
|
||||
_, err := rs.Stop(context.Background(), &pb.StopRequest{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
type MockAlgorithmStopFail struct{}
|
||||
|
||||
func (m *MockAlgorithmStopFail) Run() error { return nil }
|
||||
func (m *MockAlgorithmStopFail) Stop() error { return fmt.Errorf("stop failed") }
|
||||
+857
-102
File diff suppressed because it is too large
Load Diff
+1435
-29
File diff suppressed because it is too large
Load Diff
+9
-3
@@ -54,10 +54,11 @@ func TestAddTransition(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
@@ -79,7 +80,7 @@ func TestSetAction(t *testing.T) {
|
||||
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -88,8 +89,12 @@ func TestSetAction(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
@@ -132,10 +137,11 @@ func TestMultipleTransitions(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
transitions := []struct {
|
||||
event MockEvent
|
||||
want MockState
|
||||
|
||||
@@ -1,17 +1,33 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
statemachine "github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
)
|
||||
|
||||
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStateMachine(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *StateMachine {
|
||||
mock := &StateMachine{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// StateMachine is an autogenerated mock type for the StateMachine type
|
||||
type StateMachine struct {
|
||||
mock.Mock
|
||||
@@ -25,9 +41,10 @@ func (_m *StateMachine) EXPECT() *StateMachine_Expecter {
|
||||
return &StateMachine_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AddTransition provides a mock function with given fields: t
|
||||
func (_m *StateMachine) AddTransition(t statemachine.Transition) {
|
||||
_m.Called(t)
|
||||
// AddTransition provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) AddTransition(t statemachine.Transition) {
|
||||
_mock.Called(t)
|
||||
return
|
||||
}
|
||||
|
||||
// StateMachine_AddTransition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTransition'
|
||||
@@ -43,7 +60,13 @@ func (_e *StateMachine_Expecter) AddTransition(t interface{}) *StateMachine_AddT
|
||||
|
||||
func (_c *StateMachine_AddTransition_Call) Run(run func(t statemachine.Transition)) *StateMachine_AddTransition_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(statemachine.Transition))
|
||||
var arg0 statemachine.Transition
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(statemachine.Transition)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -53,28 +76,27 @@ func (_c *StateMachine_AddTransition_Call) Return() *StateMachine_AddTransition_
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_AddTransition_Call) RunAndReturn(run func(statemachine.Transition)) *StateMachine_AddTransition_Call {
|
||||
func (_c *StateMachine_AddTransition_Call) RunAndReturn(run func(t statemachine.Transition)) *StateMachine_AddTransition_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetState provides a mock function with no fields
|
||||
func (_m *StateMachine) GetState() statemachine.State {
|
||||
ret := _m.Called()
|
||||
// GetState provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) GetState() statemachine.State {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetState")
|
||||
}
|
||||
|
||||
var r0 statemachine.State
|
||||
if rf, ok := ret.Get(0).(func() statemachine.State); ok {
|
||||
r0 = rf()
|
||||
if returnFunc, ok := ret.Get(0).(func() statemachine.State); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(statemachine.State)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -95,8 +117,8 @@ func (_c *StateMachine_GetState_Call) Run(run func()) *StateMachine_GetState_Cal
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_GetState_Call) Return(_a0 statemachine.State) *StateMachine_GetState_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *StateMachine_GetState_Call) Return(state statemachine.State) *StateMachine_GetState_Call {
|
||||
_c.Call.Return(state)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -105,9 +127,10 @@ func (_c *StateMachine_GetState_Call) RunAndReturn(run func() statemachine.State
|
||||
return _c
|
||||
}
|
||||
|
||||
// Reset provides a mock function with given fields: initialState
|
||||
func (_m *StateMachine) Reset(initialState statemachine.State) {
|
||||
_m.Called(initialState)
|
||||
// Reset provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) Reset(initialState statemachine.State) {
|
||||
_mock.Called(initialState)
|
||||
return
|
||||
}
|
||||
|
||||
// StateMachine_Reset_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Reset'
|
||||
@@ -123,7 +146,13 @@ func (_e *StateMachine_Expecter) Reset(initialState interface{}) *StateMachine_R
|
||||
|
||||
func (_c *StateMachine_Reset_Call) Run(run func(initialState statemachine.State)) *StateMachine_Reset_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(statemachine.State))
|
||||
var arg0 statemachine.State
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(statemachine.State)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -133,14 +162,15 @@ func (_c *StateMachine_Reset_Call) Return() *StateMachine_Reset_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_Reset_Call) RunAndReturn(run func(statemachine.State)) *StateMachine_Reset_Call {
|
||||
func (_c *StateMachine_Reset_Call) RunAndReturn(run func(initialState statemachine.State)) *StateMachine_Reset_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event
|
||||
func (_m *StateMachine) SendEvent(event statemachine.Event) {
|
||||
_m.Called(event)
|
||||
// SendEvent provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) SendEvent(event statemachine.Event) {
|
||||
_mock.Called(event)
|
||||
return
|
||||
}
|
||||
|
||||
// StateMachine_SendEvent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendEvent'
|
||||
@@ -156,7 +186,13 @@ func (_e *StateMachine_Expecter) SendEvent(event interface{}) *StateMachine_Send
|
||||
|
||||
func (_c *StateMachine_SendEvent_Call) Run(run func(event statemachine.Event)) *StateMachine_SendEvent_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(statemachine.Event))
|
||||
var arg0 statemachine.Event
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(statemachine.Event)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -166,14 +202,15 @@ func (_c *StateMachine_SendEvent_Call) Return() *StateMachine_SendEvent_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_SendEvent_Call) RunAndReturn(run func(statemachine.Event)) *StateMachine_SendEvent_Call {
|
||||
func (_c *StateMachine_SendEvent_Call) RunAndReturn(run func(event statemachine.Event)) *StateMachine_SendEvent_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAction provides a mock function with given fields: state, action
|
||||
func (_m *StateMachine) SetAction(state statemachine.State, action statemachine.Action) {
|
||||
_m.Called(state, action)
|
||||
// SetAction provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) SetAction(state statemachine.State, action statemachine.Action) {
|
||||
_mock.Called(state, action)
|
||||
return
|
||||
}
|
||||
|
||||
// StateMachine_SetAction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAction'
|
||||
@@ -190,7 +227,18 @@ func (_e *StateMachine_Expecter) SetAction(state interface{}, action interface{}
|
||||
|
||||
func (_c *StateMachine_SetAction_Call) Run(run func(state statemachine.State, action statemachine.Action)) *StateMachine_SetAction_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(statemachine.State), args[1].(statemachine.Action))
|
||||
var arg0 statemachine.State
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(statemachine.State)
|
||||
}
|
||||
var arg1 statemachine.Action
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(statemachine.Action)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -200,26 +248,25 @@ func (_c *StateMachine_SetAction_Call) Return() *StateMachine_SetAction_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_SetAction_Call) RunAndReturn(run func(statemachine.State, statemachine.Action)) *StateMachine_SetAction_Call {
|
||||
func (_c *StateMachine_SetAction_Call) RunAndReturn(run func(state statemachine.State, action statemachine.Action)) *StateMachine_SetAction_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields: ctx
|
||||
func (_m *StateMachine) Start(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
// Start provides a mock function for the type StateMachine
|
||||
func (_mock *StateMachine) Start(ctx context.Context) error {
|
||||
ret := _mock.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = returnFunc(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -236,31 +283,23 @@ func (_e *StateMachine_Expecter) Start(ctx interface{}) *StateMachine_Start_Call
|
||||
|
||||
func (_c *StateMachine_Start_Call) Run(run func(ctx context.Context)) *StateMachine_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_Start_Call) Return(_a0 error) *StateMachine_Start_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *StateMachine_Start_Call) Return(err error) *StateMachine_Start_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *StateMachine_Start_Call) RunAndReturn(run func(context.Context) error) *StateMachine_Start_Call {
|
||||
func (_c *StateMachine_Start_Call) RunAndReturn(run func(ctx context.Context) error) *StateMachine_Start_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStateMachine(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *StateMachine {
|
||||
mock := &StateMachine{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -39,6 +39,7 @@ type stateMachine struct {
|
||||
transitions map[State]map[Event]State
|
||||
actions map[State]Action
|
||||
eventChan chan Event
|
||||
resetChan chan struct{}
|
||||
}
|
||||
|
||||
func NewStateMachine(initialState State) StateMachine {
|
||||
@@ -47,6 +48,7 @@ func NewStateMachine(initialState State) StateMachine {
|
||||
transitions: make(map[State]map[Event]State),
|
||||
actions: make(map[State]Action),
|
||||
eventChan: make(chan Event),
|
||||
resetChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,16 +76,31 @@ func (sm *stateMachine) GetState() State {
|
||||
}
|
||||
|
||||
func (sm *stateMachine) SendEvent(event Event) {
|
||||
sm.eventChan <- event
|
||||
sm.mu.Lock()
|
||||
eventChan := sm.eventChan
|
||||
sm.mu.Unlock()
|
||||
|
||||
select {
|
||||
case eventChan <- event:
|
||||
default:
|
||||
// Channel might be closed or full, ignore the event
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *stateMachine) Start(ctx context.Context) error {
|
||||
for {
|
||||
sm.mu.Lock()
|
||||
eventChan := sm.eventChan
|
||||
resetChan := sm.resetChan
|
||||
sm.mu.Unlock()
|
||||
|
||||
select {
|
||||
case event := <-sm.eventChan:
|
||||
case event := <-eventChan:
|
||||
if err := sm.handleEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-resetChan:
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -100,8 +117,11 @@ func (sm *stateMachine) Reset(initialState State) {
|
||||
// Close the existing event channel to stop processing events
|
||||
close(sm.eventChan)
|
||||
|
||||
// Create a new event channel
|
||||
// Close the reset channel to signal Start() to restart
|
||||
close(sm.resetChan)
|
||||
|
||||
sm.eventChan = make(chan Event)
|
||||
sm.resetChan = make(chan struct{})
|
||||
}
|
||||
|
||||
func (sm *stateMachine) handleEvent(event Event) error {
|
||||
|
||||
@@ -0,0 +1,607 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package statemachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testState string
|
||||
|
||||
func (s testState) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type testEvent string
|
||||
|
||||
func (e testEvent) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
StateIdle testState = "idle"
|
||||
StateRunning testState = "running"
|
||||
StatePaused testState = "paused"
|
||||
StateStopped testState = "stopped"
|
||||
StateError testState = "error"
|
||||
)
|
||||
|
||||
const (
|
||||
EventStart testEvent = "start"
|
||||
EventPause testEvent = "pause"
|
||||
EventStop testEvent = "stop"
|
||||
EventReset testEvent = "reset"
|
||||
EventError testEvent = "error"
|
||||
)
|
||||
|
||||
func TestNewStateMachine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
want State
|
||||
}{
|
||||
{
|
||||
name: "create with idle state",
|
||||
initialState: StateIdle,
|
||||
want: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "create with running state",
|
||||
initialState: StateRunning,
|
||||
want: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "create with custom state",
|
||||
initialState: testState("custom"),
|
||||
want: testState("custom"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
if got := sm.GetState(); got != tt.want {
|
||||
t.Errorf("NewStateMachine() initial state = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_AddTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
transitions []Transition
|
||||
from State
|
||||
event Event
|
||||
expectTo State
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "single transition",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventStart,
|
||||
expectTo: StateRunning,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "multiple transitions from same state",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateIdle, Event: EventError, To: StateError},
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventError,
|
||||
expectTo: StateError,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "overwrite existing transition",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateIdle, Event: EventStart, To: StatePaused}, // Overwrite
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventStart,
|
||||
expectTo: StatePaused,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "transition not found",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
from: StateRunning,
|
||||
event: EventPause,
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle).(*stateMachine)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
nextState, valid := sm.transitions[tt.from][tt.event]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if valid != tt.expectValid {
|
||||
t.Errorf("Transition validity = %v, want %v", valid, tt.expectValid)
|
||||
}
|
||||
|
||||
if tt.expectValid && nextState != tt.expectTo {
|
||||
t.Errorf("Transition destination = %v, want %v", nextState, tt.expectTo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_SetAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
state State
|
||||
action Action
|
||||
expectAction bool
|
||||
}{
|
||||
{
|
||||
name: "set action for state",
|
||||
state: StateRunning,
|
||||
action: func(s State) {
|
||||
},
|
||||
expectAction: true,
|
||||
},
|
||||
{
|
||||
name: "set nil action",
|
||||
state: StatePaused,
|
||||
action: nil,
|
||||
expectAction: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle).(*stateMachine)
|
||||
sm.SetAction(tt.state, tt.action)
|
||||
|
||||
sm.mu.Lock()
|
||||
action := sm.actions[tt.state]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if tt.expectAction && action == nil {
|
||||
t.Error("Expected action to be set, but it was nil")
|
||||
}
|
||||
if !tt.expectAction && action != nil {
|
||||
t.Error("Expected action to be nil, but it was set")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_GetState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
events []Event
|
||||
finalState State
|
||||
}{
|
||||
{
|
||||
name: "get initial state",
|
||||
initialState: StateIdle,
|
||||
finalState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "get state after transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventStart},
|
||||
finalState: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "get state after multiple transitions",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventPause, To: StatePaused},
|
||||
{From: StatePaused, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventStart, EventPause, EventStart},
|
||||
finalState: StateRunning,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
smImpl := sm.(*stateMachine)
|
||||
for _, event := range tt.events {
|
||||
if err := smImpl.handleEvent(event); err != nil {
|
||||
t.Fatalf("Failed to handle event %v: %v", event, err)
|
||||
}
|
||||
}
|
||||
|
||||
if got := sm.GetState(); got != tt.finalState {
|
||||
t.Errorf("GetState() = %v, want %v", got, tt.finalState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Start(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
events []Event
|
||||
cancelAfter time.Duration
|
||||
expectError bool
|
||||
expectedStates []State
|
||||
}{
|
||||
{
|
||||
name: "start and cancel immediately",
|
||||
initialState: StateIdle,
|
||||
cancelAfter: 10 * time.Millisecond,
|
||||
expectError: true, // context.Canceled
|
||||
},
|
||||
{
|
||||
name: "process events then cancel",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventStop, To: StateStopped},
|
||||
},
|
||||
events: []Event{EventStart, EventStop},
|
||||
cancelAfter: 100 * time.Millisecond,
|
||||
expectError: true, // context.Canceled
|
||||
expectedStates: []State{StateRunning, StateStopped},
|
||||
},
|
||||
{
|
||||
name: "invalid transition error",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventPause}, // Invalid from StateIdle
|
||||
cancelAfter: 50 * time.Millisecond,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
var states []State
|
||||
var mu sync.Mutex
|
||||
|
||||
for _, state := range tt.expectedStates {
|
||||
sm.SetAction(state, func(s State) {
|
||||
mu.Lock()
|
||||
states = append(states, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- sm.Start(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
for _, event := range tt.events {
|
||||
sm.SendEvent(event)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
time.Sleep(tt.cancelAfter)
|
||||
cancel()
|
||||
|
||||
err := <-errChan
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
if len(states) != len(tt.expectedStates) {
|
||||
t.Errorf("Expected %d state changes, got %d", len(tt.expectedStates), len(states))
|
||||
}
|
||||
for i, expectedState := range tt.expectedStates {
|
||||
if i < len(states) && states[i] != expectedState {
|
||||
t.Errorf("State change %d = %v, want %v", i, states[i], expectedState)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Reset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
resetState State
|
||||
setupTransitions []Transition
|
||||
eventsBeforeReset []Event
|
||||
eventsAfterReset []Event
|
||||
expectedState State
|
||||
}{
|
||||
{
|
||||
name: "reset to same state",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "reset to different state",
|
||||
initialState: StateIdle,
|
||||
resetState: StateRunning,
|
||||
expectedState: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "reset after state changes",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
setupTransitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
eventsBeforeReset: []Event{EventStart},
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "reset and send new events",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
setupTransitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventStop, To: StateStopped},
|
||||
},
|
||||
eventsBeforeReset: []Event{EventStart},
|
||||
eventsAfterReset: []Event{EventStart},
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
smImpl := sm.(*stateMachine)
|
||||
|
||||
for _, transition := range tt.setupTransitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
for _, event := range tt.eventsBeforeReset {
|
||||
if err := smImpl.handleEvent(event); err != nil {
|
||||
// Ignore errors for this test
|
||||
}
|
||||
}
|
||||
|
||||
sm.Reset(tt.resetState)
|
||||
|
||||
if got := sm.GetState(); got != tt.expectedState {
|
||||
t.Errorf("State after reset = %v, want %v", got, tt.expectedState)
|
||||
}
|
||||
|
||||
for _, event := range tt.eventsAfterReset {
|
||||
sm.SendEvent(event)
|
||||
}
|
||||
|
||||
// For events after reset, we can't easily check the channel length
|
||||
// due to the synchronization changes, so we just verify the reset worked
|
||||
if len(tt.eventsAfterReset) > 0 {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Reset_WithRunningStateMachine(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle)
|
||||
sm.AddTransition(Transition{From: StateIdle, Event: EventStart, To: StateRunning})
|
||||
sm.AddTransition(Transition{From: StateRunning, Event: EventStop, To: StateStopped})
|
||||
|
||||
var stateChanges []State
|
||||
var mu sync.Mutex
|
||||
|
||||
sm.SetAction(StateRunning, func(s State) {
|
||||
mu.Lock()
|
||||
stateChanges = append(stateChanges, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
sm.SetAction(StateStopped, func(s State) {
|
||||
mu.Lock()
|
||||
stateChanges = append(stateChanges, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Send an event
|
||||
sm.SendEvent(EventStart)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Reset while running
|
||||
sm.Reset(StateIdle)
|
||||
|
||||
// Verify state was reset
|
||||
if got := sm.GetState(); got != StateIdle {
|
||||
t.Errorf("State after reset = %v, want %v", got, StateIdle)
|
||||
}
|
||||
|
||||
// Send another event after reset
|
||||
sm.SendEvent(EventStart)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
changes := len(stateChanges)
|
||||
mu.Unlock()
|
||||
|
||||
// Should have at least processed the first event
|
||||
if changes < 1 {
|
||||
t.Errorf("Expected at least 1 state change, got %d", changes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_HandleEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
event Event
|
||||
expectedState State
|
||||
expectError bool
|
||||
expectActionCall bool
|
||||
}{
|
||||
{
|
||||
name: "valid transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateRunning,
|
||||
expectError: false,
|
||||
expectActionCall: true,
|
||||
},
|
||||
{
|
||||
name: "invalid transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateRunning, Event: EventPause, To: StatePaused},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateIdle,
|
||||
expectError: true,
|
||||
expectActionCall: false,
|
||||
},
|
||||
{
|
||||
name: "transition with no action",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateRunning,
|
||||
expectError: false,
|
||||
expectActionCall: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState).(*stateMachine)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
var actionCalled bool
|
||||
var mu sync.Mutex
|
||||
|
||||
if tt.expectActionCall {
|
||||
sm.SetAction(tt.expectedState, func(s State) {
|
||||
mu.Lock()
|
||||
actionCalled = true
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
err := sm.handleEvent(tt.event)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if sm.GetState() != tt.expectedState {
|
||||
t.Errorf("State after handleEvent = %v, want %v", sm.GetState(), tt.expectedState)
|
||||
}
|
||||
|
||||
if tt.expectActionCall {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
mu.Lock()
|
||||
called := actionCalled
|
||||
mu.Unlock()
|
||||
if !called {
|
||||
t.Error("Expected action to be called but it wasn't")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_SendEvent_ThreadSafety(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle)
|
||||
sm.AddTransition(Transition{From: StateIdle, Event: EventStart, To: StateRunning})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
eventsPerGoroutine := 100
|
||||
|
||||
// Send events concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < eventsPerGoroutine; j++ {
|
||||
sm.SendEvent(EventStart)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// If we reach here without panicking, the test passes
|
||||
}
|
||||
@@ -1,9 +1,6 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build embed
|
||||
// +build embed
|
||||
|
||||
package cocosai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
# CoRIM Generation CLI Commands
|
||||
|
||||
This document describes the CLI commands for generating CoRIM (Concise Reference Integrity Manifest) attestation policies.
|
||||
|
||||
## Overview
|
||||
|
||||
The `cocos-cli policy create-corim` command provides subcommands for generating CoRIM policies for different platforms:
|
||||
- **azure**: Generate from Azure Attestation Token
|
||||
- **gcp**: Generate from GCP endorsements
|
||||
- **snp**: Generate for AMD SEV-SNP (direct host generation)
|
||||
- **tdx**: Generate for Intel TDX (direct host generation)
|
||||
|
||||
## Commands
|
||||
|
||||
### Azure SEV-SNP
|
||||
|
||||
Generate CoRIM from an Azure Attestation Token (JWT).
|
||||
|
||||
```bash
|
||||
cocos-cli policy create-corim azure --token <path-to-token> [--product <product>]
|
||||
```
|
||||
|
||||
**Flags:**
|
||||
- `--token` (required): Path to file containing Azure Attestation Token (JWT)
|
||||
- `--product` (optional): Processor product name (default: "Milan")
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
cocos-cli policy create-corim azure \
|
||||
--token /path/to/token.jwt \
|
||||
--product Milan \
|
||||
> azure-policy.corim
|
||||
```
|
||||
|
||||
### GCP SEV-SNP
|
||||
|
||||
Generate CoRIM from GCP SEV-SNP measurement and endorsements.
|
||||
|
||||
```bash
|
||||
cocos-cli policy create-corim gcp --measurement <hex> [--vcpu <num>]
|
||||
```
|
||||
|
||||
**Flags:**
|
||||
- `--measurement` (required): 384-bit measurement hex string
|
||||
- `--vcpu` (optional): vCPU number (default: 0)
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
cocos-cli policy create-corim gcp \
|
||||
--measurement abc123... \
|
||||
--vcpu 0 \
|
||||
> gcp-policy.corim
|
||||
```
|
||||
|
||||
### SEV-SNP (Direct Host)
|
||||
|
||||
Generate CoRIM for AMD SEV-SNP platform directly on the host.
|
||||
|
||||
```bash
|
||||
cocos-cli policy create-corim snp [flags]
|
||||
```
|
||||
|
||||
**Flags:**
|
||||
- `--measurement` (optional): Measurement/Launch Digest (hex string, defaults to zero if not provided)
|
||||
- `--policy` (optional): SNP policy flags (default: 0)
|
||||
- `--svn` (optional): Security Version Number/TCB (default: 0)
|
||||
- `--product` (optional): Processor product name (default: "Milan")
|
||||
- `--host-data` (optional): Host data (hex string)
|
||||
- `--launch-tcb` (optional): Minimum launch TCB (default: 0)
|
||||
- `--output` (optional): Output file path (default: stdout)
|
||||
|
||||
**Examples:**
|
||||
|
||||
Generate with defaults (zeroed measurement):
|
||||
```bash
|
||||
cocos-cli policy create-corim snp \
|
||||
--product Milan \
|
||||
--output snp-policy.corim
|
||||
```
|
||||
|
||||
Generate with custom measurement:
|
||||
```bash
|
||||
cocos-cli policy create-corim snp \
|
||||
--measurement abc123def456... \
|
||||
--product Genoa \
|
||||
--svn 1 \
|
||||
--policy 0x30000 \
|
||||
--output snp-policy.corim
|
||||
```
|
||||
|
||||
Generate with host data and launch TCB:
|
||||
```bash
|
||||
cocos-cli policy create-corim snp \
|
||||
--measurement abc123... \
|
||||
--host-data deadbeef \
|
||||
--launch-tcb 1 \
|
||||
--output snp-policy.corim
|
||||
```
|
||||
|
||||
### TDX (Direct Host)
|
||||
|
||||
Generate CoRIM for Intel TDX platform directly on the host.
|
||||
|
||||
```bash
|
||||
cocos-cli policy create-corim tdx [flags]
|
||||
```
|
||||
|
||||
**Flags:**
|
||||
- `--measurement` (optional): MRTD measurement (hex string, uses default if not provided)
|
||||
- `--svn` (optional): Security Version Number (default: 0)
|
||||
- `--rtmrs` (optional): Comma-separated RTMRs (hex)
|
||||
- `--mr-seam` (optional): MRSEAM (hex)
|
||||
- `--output` (optional): Output file path (default: stdout)
|
||||
|
||||
**Examples:**
|
||||
|
||||
Generate with defaults (matches legacy script behavior):
|
||||
```bash
|
||||
cocos-cli policy create-corim tdx \
|
||||
--output tdx-policy.corim
|
||||
```
|
||||
|
||||
Generate with custom values:
|
||||
```bash
|
||||
cocos-cli policy create-corim tdx \
|
||||
--measurement abc123def456... \
|
||||
--rtmrs rtmr0,rtmr1,rtmr2,rtmr3 \
|
||||
--mr-seam 789abc... \
|
||||
--svn 2 \
|
||||
--output tdx-policy.corim
|
||||
```
|
||||
|
||||
## Signing CoRIMs
|
||||
|
||||
CoRIMs can be signed using a private key (COSE_Sign1). The generated output will be a COSE-wrapped CoRIM in CBOR format.
|
||||
|
||||
### Prerequisite: Generate Signing Key
|
||||
|
||||
You will need an EC private key (P-256) in PEM format. You can generate one using `openssl`:
|
||||
|
||||
```bash
|
||||
openssl ecparam -name prime256v1 -genkey -noout -out private-key.pem
|
||||
```
|
||||
|
||||
### Signing with CLI
|
||||
|
||||
Use the `--signing-key` flag to sign the CoRIM during generation.
|
||||
|
||||
**SNP Example:**
|
||||
```bash
|
||||
cocos-cli policy create-corim snp \
|
||||
--product Milan \
|
||||
--signing-key private-key.pem \
|
||||
--output signed-snp.corim
|
||||
```
|
||||
|
||||
**TDX Example:**
|
||||
```bash
|
||||
cocos-cli policy create-corim tdx \
|
||||
--signing-key private-key.pem \
|
||||
--output signed-tdx.corim
|
||||
```
|
||||
|
||||
### Verification
|
||||
|
||||
The output file is a standard COSE_Sign1 message containing the CoRIM. It can be verified using any tool that supports COSE and CoRIM verification, such as the [veraison/corim](https://github.com/veraison/corim) library.
|
||||
|
||||
## Output Format
|
||||
|
||||
All commands output CoRIM in CBOR (Concise Binary Object Representation) format. By default, output is written to stdout, allowing for piping:
|
||||
|
||||
```bash
|
||||
# Pipe to file
|
||||
cocos-cli policy create-corim snp --product Milan > policy.corim
|
||||
|
||||
# Pipe to another command
|
||||
cocos-cli policy create-corim tdx | base64
|
||||
|
||||
# Use --output flag
|
||||
cocos-cli policy create-corim snp --product Milan --output policy.corim
|
||||
```
|
||||
|
||||
## Integration with Manager
|
||||
|
||||
The manager service can dynamically generate CoRIM policies using the same underlying generator package. When `FetchAttestationPolicy` is called:
|
||||
|
||||
1. For SNP: Calculates IGVM measurement using the `igvmmeasure` binary
|
||||
2. Extracts host data and launch TCB from VM configuration
|
||||
3. Generates CoRIM using the `generator` package
|
||||
4. Returns CBOR-encoded CoRIM
|
||||
|
||||
## See Also
|
||||
|
||||
- [Generator Package Documentation](../pkg/attestation/generator/README.md)
|
||||
- [IGVM Measure Package Documentation](../pkg/attestation/igvmmeasure/README.md)
|
||||
- [Manager README](../manager/README.md)
|
||||
+6
-6
@@ -29,7 +29,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
algorithm, err := os.Open(algorithmFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
if requirementsFile != "" {
|
||||
req, err = os.Open(requirementsFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer req.Close()
|
||||
@@ -57,7 +57,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,14 +65,14 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
|
||||
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil {
|
||||
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+146
-832
File diff suppressed because it is too large
Load Diff
+29
-196
@@ -5,171 +5,38 @@ package cli
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/gcp"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type fieldType int
|
||||
|
||||
const (
|
||||
measurementField fieldType = iota
|
||||
hostDataField
|
||||
)
|
||||
|
||||
const (
|
||||
// 0o744 file permission gives RWX permission to the user and only the R permission to others.
|
||||
filePermission = 0o744
|
||||
// Length of the expected host data and measurement field in bytes.
|
||||
hostDataLength = 32
|
||||
measurementLength = 48
|
||||
)
|
||||
|
||||
var (
|
||||
errDecode = errors.New("base64 string could not be decoded")
|
||||
errDataLength = errors.New("data does not have an adequate length")
|
||||
errReadingAttestationPolicyFile = errors.New("error while reading the attestation policy file")
|
||||
errUnmarshalJSON = errors.New("failed to unmarshal json")
|
||||
errMarshalJSON = errors.New("failed to marshal json")
|
||||
errWriteFile = errors.New("failed to write to file")
|
||||
errAttestationPolicyField = errors.New("the specified field type does not exist in the attestation policy")
|
||||
isJsonAttestation bool
|
||||
// 0o744 file permission gives RWX permission to the user and only the R permission to others.
|
||||
filePermission os.FileMode = 0o744
|
||||
)
|
||||
|
||||
func (cli *CLI) NewAttestationPolicyCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "policy [command]",
|
||||
cmd := &cobra.Command{
|
||||
Use: "policy",
|
||||
Short: "Change attestation policy",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("Change attestation policy\n\n")
|
||||
fmt.Printf("Usage:\n %s [command]\n\n", cmd.CommandPath())
|
||||
fmt.Printf("Available Commands:\n")
|
||||
|
||||
// Filter out "completion" command
|
||||
availableCommands := make([]*cobra.Command, 0)
|
||||
for _, subCmd := range cmd.Commands() {
|
||||
if subCmd.Name() != "completion" {
|
||||
availableCommands = append(availableCommands, subCmd)
|
||||
}
|
||||
}
|
||||
|
||||
for _, subCmd := range availableCommands {
|
||||
fmt.Printf(" %-15s%s\n", subCmd.Name(), subCmd.Short)
|
||||
}
|
||||
|
||||
fmt.Printf("\nFlags:\n")
|
||||
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
|
||||
fmt.Printf(" -%s, --%s %s\n", flag.Shorthand, flag.Name, flag.Usage)
|
||||
})
|
||||
fmt.Printf("\nUse \"%s [command] --help\" for more information about a command.\n", cmd.CommandPath())
|
||||
_ = cmd.Help()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *CLI) NewAddMeasurementCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "measurement",
|
||||
Short: "Add measurement to the attestation policy file. The value should be in base64. The second parameter is attestation_policy.json file",
|
||||
Example: "measurement <measurement> <attestation_policy.json>",
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := changeAttestationConfiguration(args[1], args[0], measurementLength, measurementField); err != nil {
|
||||
printError(cmd, "Error could not change measurement data: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
cmd.AddCommand(cli.NewCreateCoRIMCmd())
|
||||
|
||||
func (cli *CLI) NewAddHostDataCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "hostdata",
|
||||
Short: "Add host data to the attestation policy file. The value should be in base64. The second parameter is attestation_policy.json file",
|
||||
Example: "hostdata <host-data> <attestation_policy.json>",
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := changeAttestationConfiguration(args[1], args[0], hostDataLength, hostDataField); err != nil {
|
||||
printError(cmd, "Error could not change host data: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *CLI) NewGCPAttestationPolicy() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "gcp",
|
||||
Short: "Get attestation policy for GCP CVM",
|
||||
Example: `gcp <bin_vtmp_attestation_report_file> <vcpu_count>`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
attestationBin, err := os.ReadFile(args[0])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading attestation report file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
vcpuCount, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
printError(cmd, "Error converting vCPU count to integer: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
attestation := &attest.Attestation{}
|
||||
|
||||
if err := proto.Unmarshal(attestationBin, attestation); err != nil {
|
||||
printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
attestationPB := attestation.GetSevSnpAttestation()
|
||||
|
||||
measurement, err := gcp.Extract384BitMeasurement(attestationPB)
|
||||
if err != nil {
|
||||
printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
launchEndorsement, err := gcp.GetLaunchEndorsement(cmd.Context(), measurement)
|
||||
if err != nil {
|
||||
printError(cmd, "Error getting launch endorsement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
attestationPolicy, err := gcp.GenerateAttestationPolicy(launchEndorsement, uint32(vcpuCount))
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating attestation policy: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
attestationPolicyJson, err := json.MarshalIndent(attestationPolicy, "", " ")
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshaling attestation policy: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile("attestation_policy.json", attestationPolicyJson, filePermission); err != nil {
|
||||
printError(cmd, "Error writing attestation policy file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("Attestation policy file generated successfully ✅")
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
cmd := &cobra.Command{
|
||||
Use: "download",
|
||||
Short: "Download GCP OVMF file",
|
||||
Example: `download <bin_vtmp_attestation_report_file>`,
|
||||
@@ -177,95 +44,61 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
attestationBin, err := os.ReadFile(args[0])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading attestation report file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading attestation report file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
attestation := &attest.Attestation{}
|
||||
|
||||
if err := proto.Unmarshal(attestationBin, attestation); err != nil {
|
||||
printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err)
|
||||
return
|
||||
if isJsonAttestation {
|
||||
if err := protojson.Unmarshal(attestationBin, attestation); err != nil {
|
||||
cli.printError(cmd, "Error converting JSON attestation to binary: %v ❌", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := proto.Unmarshal(attestationBin, attestation); err != nil {
|
||||
cli.printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
attestationPB := attestation.GetSevSnpAttestation()
|
||||
|
||||
measurement, err := gcp.Extract384BitMeasurement(attestationPB)
|
||||
if err != nil {
|
||||
printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
launchEndorsement, err := gcp.GetLaunchEndorsement(cmd.Context(), measurement)
|
||||
if err != nil {
|
||||
printError(cmd, "Error getting launch endorsement: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error getting launch endorsement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ovmf, err := gcp.DownloadOvmfFile(cmd.Context(), fmt.Sprintf("%x", launchEndorsement.Digest))
|
||||
if err != nil {
|
||||
printError(cmd, "Error downloading OVMF file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error downloading OVMF file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
sum384 := sha512.Sum384(ovmf)
|
||||
|
||||
if !bytes.Equal(sum384[:], launchEndorsement.Digest) {
|
||||
printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch"))
|
||||
cli.printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch"))
|
||||
} else {
|
||||
cmd.Println("OVMF firmware in vm is unmodified ✅")
|
||||
}
|
||||
|
||||
if err := os.WriteFile("ovmf.fd", ovmf, filePermission); err != nil {
|
||||
printError(cmd, "Error writing OVMF file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error writing OVMF file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("OVMF file downloaded successfully ✅")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func changeAttestationConfiguration(fileName, base64Data string, expectedLength int, field fieldType) error {
|
||||
data, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return errDecode
|
||||
}
|
||||
|
||||
if len(data) != expectedLength {
|
||||
return errDataLength
|
||||
}
|
||||
|
||||
ac := config.Config{Config: &check.Config{RootOfTrust: &check.RootOfTrust{}, Policy: &check.Policy{}}, PcrConfig: &config.PcrConfig{}}
|
||||
|
||||
f, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return errors.Wrap(errReadingAttestationPolicyFile, err)
|
||||
}
|
||||
|
||||
if err = config.ReadAttestationPolicyFromByte(f, &ac); err != nil {
|
||||
return errors.Wrap(errUnmarshalJSON, err)
|
||||
}
|
||||
|
||||
if ac.Config.Policy == nil {
|
||||
ac.Config.Policy = &check.Policy{}
|
||||
}
|
||||
|
||||
switch field {
|
||||
case measurementField:
|
||||
ac.Config.Policy.Measurement = data
|
||||
case hostDataField:
|
||||
ac.Config.Policy.HostData = data
|
||||
default:
|
||||
return errAttestationPolicyField
|
||||
}
|
||||
|
||||
fileJson, err := json.MarshalIndent(&ac, "", " ")
|
||||
if err != nil {
|
||||
return errors.Wrap(errMarshalJSON, err)
|
||||
}
|
||||
if err = os.WriteFile(fileName, fileJson, filePermission); err != nil {
|
||||
return errors.Wrap(errWriteFile, err)
|
||||
}
|
||||
return nil
|
||||
|
||||
cmd.Flags().BoolVarP(&isJsonAttestation, "json", "j", false, "Use JSON attestation report instead of binary")
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/corimgen"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/gcp"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/generator"
|
||||
)
|
||||
|
||||
func (cli *CLI) NewCreateCoRIMCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create-corim",
|
||||
Short: "Create CoRIM attestation policy",
|
||||
Long: `Create CoRIM attestation policy for supported platforms (Azure, GCP, SNP, TDX)`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(cli.NewCreateCoRIMAzureCmd())
|
||||
cmd.AddCommand(cli.NewCreateCoRIMGCPCmd())
|
||||
cmd.AddCommand(cli.NewCreateCoRIMSNPCmd())
|
||||
cmd.AddCommand(cli.NewCreateCoRIMTDXCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cli *CLI) NewCreateCoRIMAzureCmd() *cobra.Command {
|
||||
var tokenPath string
|
||||
var product string
|
||||
var output string
|
||||
var signingKeyPath string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "azure",
|
||||
Short: "Create CoRIM for Azure SEV-SNP",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
tokenBytes, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
azureData, err := azure.ExtractAzureMeasurement(string(tokenBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract Azure measurements: %w", err)
|
||||
}
|
||||
|
||||
opts := generator.Options{
|
||||
Platform: "snp",
|
||||
Measurement: azureData.Measurement,
|
||||
HostData: azureData.HostData,
|
||||
Policy: azureData.Policy,
|
||||
SVN: azureData.SVN,
|
||||
Product: product,
|
||||
}
|
||||
|
||||
if signingKeyPath != "" {
|
||||
key, err := corimgen.LoadSigningKey(signingKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load signing key: %w", err)
|
||||
}
|
||||
opts.SigningKey = key
|
||||
}
|
||||
|
||||
cborBytes, err := generator.GenerateCoRIM(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate CoRIM: %w", err)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
if err := os.WriteFile(output, cborBytes, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "CoRIM written to %s\n", output)
|
||||
} else {
|
||||
if _, err := cmd.OutOrStdout().Write(cborBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&tokenPath, "token", "", "Path to file containing Azure Attestation Token (JWT)")
|
||||
cmd.Flags().StringVar(&product, "product", "Milan", "Processor product name (Milan, Genoa)")
|
||||
cmd.Flags().StringVar(&output, "output", "", "Output file path (default: stdout)")
|
||||
cmd.Flags().StringVar(&signingKeyPath, "signing-key", "", "Path to private key for signing (PEM format)")
|
||||
_ = cmd.MarkFlagRequired("token")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cli *CLI) NewCreateCoRIMGCPCmd() *cobra.Command {
|
||||
var measurement string
|
||||
var vcpuNum uint32
|
||||
var output string
|
||||
var signingKeyPath string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "gcp",
|
||||
Short: "Create CoRIM for GCP SEV-SNP",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
endorsement, err := gcp.GetLaunchEndorsement(ctx, measurement)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get launch endorsement: %w", err)
|
||||
}
|
||||
|
||||
gcpData, err := gcp.ExtractGCPMeasurement(endorsement, vcpuNum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract GCP measurements: %w", err)
|
||||
}
|
||||
|
||||
opts := generator.Options{
|
||||
Platform: "snp",
|
||||
Measurement: gcpData.Measurement,
|
||||
Policy: gcpData.Policy,
|
||||
}
|
||||
|
||||
if signingKeyPath != "" {
|
||||
key, err := corimgen.LoadSigningKey(signingKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load signing key: %w", err)
|
||||
}
|
||||
opts.SigningKey = key
|
||||
}
|
||||
|
||||
cborBytes, err := generator.GenerateCoRIM(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate CoRIM: %w", err)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
if err := os.WriteFile(output, cborBytes, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "CoRIM written to %s\n", output)
|
||||
} else {
|
||||
if _, err := cmd.OutOrStdout().Write(cborBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&measurement, "measurement", "", "384-bit measurement hex string")
|
||||
cmd.Flags().Uint32Var(&vcpuNum, "vcpu", 0, "vCPU number")
|
||||
cmd.Flags().StringVar(&output, "output", "", "Output file path (default: stdout)")
|
||||
cmd.Flags().StringVar(&signingKeyPath, "signing-key", "", "Path to private key for signing (PEM format)")
|
||||
_ = cmd.MarkFlagRequired("measurement")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cli *CLI) NewCreateCoRIMSNPCmd() *cobra.Command {
|
||||
var (
|
||||
measurement string
|
||||
policy uint64
|
||||
svn uint64
|
||||
product string
|
||||
hostData string
|
||||
launchTCB uint64
|
||||
output string
|
||||
signingKeyPath string
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "snp",
|
||||
Short: "Create CoRIM for SEV-SNP",
|
||||
Long: `Generate CoRIM attestation policy for AMD SEV-SNP platform`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
opts := generator.Options{
|
||||
Platform: "snp",
|
||||
Measurement: measurement,
|
||||
Policy: policy,
|
||||
SVN: svn,
|
||||
Product: product,
|
||||
HostData: hostData,
|
||||
LaunchTCB: launchTCB,
|
||||
}
|
||||
|
||||
if signingKeyPath != "" {
|
||||
key, err := corimgen.LoadSigningKey(signingKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load signing key: %w", err)
|
||||
}
|
||||
opts.SigningKey = key
|
||||
}
|
||||
|
||||
cborBytes, err := generator.GenerateCoRIM(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate CoRIM: %w", err)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
if err := os.WriteFile(output, cborBytes, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "CoRIM written to %s\n", output)
|
||||
} else {
|
||||
if _, err := cmd.OutOrStdout().Write(cborBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&measurement, "measurement", "", "Measurement/Launch Digest (hex string, defaults to zero if not provided)")
|
||||
cmd.Flags().Uint64Var(&policy, "policy", 0, "SNP policy flags")
|
||||
cmd.Flags().Uint64Var(&svn, "svn", 0, "Security Version Number (TCB)")
|
||||
cmd.Flags().StringVar(&product, "product", "Milan", "Processor product name (Milan, Genoa, etc.)")
|
||||
cmd.Flags().StringVar(&hostData, "host-data", "", "Host data (hex string)")
|
||||
cmd.Flags().Uint64Var(&launchTCB, "launch-tcb", 0, "Minimum launch TCB")
|
||||
cmd.Flags().StringVar(&output, "output", "", "Output file path (default: stdout)")
|
||||
cmd.Flags().StringVar(&signingKeyPath, "signing-key", "", "Path to private key for signing (PEM format)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cli *CLI) NewCreateCoRIMTDXCmd() *cobra.Command {
|
||||
var (
|
||||
measurement string
|
||||
svn uint64
|
||||
rtmrs string
|
||||
mrSeam string
|
||||
output string
|
||||
signingKeyPath string
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "tdx",
|
||||
Short: "Create CoRIM for Intel TDX",
|
||||
Long: `Generate CoRIM attestation policy for Intel TDX platform`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
opts := generator.Options{
|
||||
Platform: "tdx",
|
||||
Measurement: measurement,
|
||||
SVN: svn,
|
||||
RTMRs: rtmrs,
|
||||
MrSeam: mrSeam,
|
||||
}
|
||||
|
||||
if signingKeyPath != "" {
|
||||
key, err := corimgen.LoadSigningKey(signingKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load signing key: %w", err)
|
||||
}
|
||||
opts.SigningKey = key
|
||||
}
|
||||
|
||||
cborBytes, err := generator.GenerateCoRIM(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate CoRIM: %w", err)
|
||||
}
|
||||
|
||||
if output != "" {
|
||||
if err := os.WriteFile(output, cborBytes, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
fmt.Fprintf(cmd.ErrOrStderr(), "CoRIM written to %s\n", output)
|
||||
} else {
|
||||
if _, err := cmd.OutOrStdout().Write(cborBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&measurement, "measurement", "", "MRTD measurement (hex string, uses default if not provided)")
|
||||
cmd.Flags().Uint64Var(&svn, "svn", 0, "Security Version Number")
|
||||
cmd.Flags().StringVar(&rtmrs, "rtmrs", "", "Comma-separated RTMRs (hex)")
|
||||
cmd.Flags().StringVar(&mrSeam, "mr-seam", "", "MRSEAM (hex)")
|
||||
cmd.Flags().StringVar(&output, "output", "", "Output file path (default: stdout)")
|
||||
cmd.Flags().StringVar(&signingKeyPath, "signing-key", "", "Path to private key for signing (PEM format)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,389 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gce-tcb-verifier/proto/endorsement"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/gcp"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestCLI_NewCreateCoRIMCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMCmd()
|
||||
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "create-corim", cmd.Use)
|
||||
assert.True(t, cmd.HasSubCommands())
|
||||
|
||||
subcmds := cmd.Commands()
|
||||
assert.Equal(t, 4, len(subcmds))
|
||||
|
||||
cmdNames := make(map[string]bool)
|
||||
for _, sc := range subcmds {
|
||||
cmdNames[sc.Name()] = true
|
||||
}
|
||||
|
||||
assert.True(t, cmdNames["azure"])
|
||||
assert.True(t, cmdNames["gcp"])
|
||||
assert.True(t, cmdNames["snp"])
|
||||
assert.True(t, cmdNames["tdx"])
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMSNPCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMSNPCmd()
|
||||
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "snp", cmd.Use)
|
||||
|
||||
// Test with minimal flags
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"--measurement", "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, outBuf.Bytes())
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMTDXCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMTDXCmd()
|
||||
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "tdx", cmd.Use)
|
||||
|
||||
// Test with minimal flags
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"--measurement", "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, outBuf.Bytes())
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMAzureCmd_Error(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMAzureCmd()
|
||||
|
||||
// Missing token flag
|
||||
cmd.SetArgs([]string{})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
|
||||
// Non-existent token file
|
||||
cmd.SetArgs([]string{"--token", "non-existent-file"})
|
||||
err = cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read token file")
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMGCPCmd_Error(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMGCPCmd()
|
||||
|
||||
// Missing measurement flag
|
||||
cmd.SetArgs([]string{})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
|
||||
// GCP command will fail because it tries to call Google Cloud Storage
|
||||
cmd.SetArgs([]string{"--measurement", "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff"})
|
||||
err = cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
// It should fail at GetLaunchEndorsement or storage client creation
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMAzureCmd_Success(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMAzureCmd()
|
||||
|
||||
oldValidator := azure.DefaultValidator
|
||||
defer func() { azure.DefaultValidator = oldValidator }()
|
||||
|
||||
azure.DefaultValidator = &mockTokenValidator{
|
||||
validateFunc: func(token string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"x-ms-isolation-tee": map[string]any{
|
||||
"x-ms-sevsnpvm-launchmeasurement": "00112233",
|
||||
"x-ms-sevsnpvm-guestsvn": 1.0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tokenPath := filepath.Join(tmpDir, "token.jwt")
|
||||
// Dummy token
|
||||
dummyToken := "eyJhbGciOiJub25lIn0.eyJoZWFkZXIiOiJkYXRhIn0."
|
||||
err := os.WriteFile(tokenPath, []byte(dummyToken), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"--token", tokenPath})
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, outBuf.Bytes())
|
||||
|
||||
// Test with output file
|
||||
outputFile := filepath.Join(tmpDir, "azure-corim.cbor")
|
||||
cmd.SetArgs([]string{"--token", tokenPath, "--output", outputFile})
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(outputFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with signing key
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
err = os.WriteFile(keyPath, []byte("-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEIJ+3b6N6Y9J2H9f9X9X9X9X9X9X9X9X9X9X9X9X9X9X9\n-----END PRIVATE KEY-----"), 0o644)
|
||||
require.NoError(t, err)
|
||||
cmd.SetArgs([]string{"--token", tokenPath, "--signing-key", keyPath})
|
||||
err = cmd.Execute()
|
||||
assert.Error(t, err) // Should fail with invalid key but we cover the path
|
||||
// This might fail if the key is not valid Ed25519 for corimgen, but we want to cover the path
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMGCPCmd_More(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMGCPCmd()
|
||||
|
||||
oldNewStorageClient := gcp.NewStorageClient
|
||||
defer func() { gcp.NewStorageClient = oldNewStorageClient }()
|
||||
|
||||
gcp.NewStorageClient = func(ctx context.Context) (gcp.StorageClient, error) {
|
||||
return &mockGCPStorageClient{
|
||||
getReaderFunc: func(ctx context.Context, bucket, object string) (io.ReadCloser, error) {
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 123,
|
||||
Measurements: map[uint32][]byte{1: {0x1, 0x2}},
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
launchBytes, _ := proto.Marshal(launchEndorsement)
|
||||
return io.NopCloser(bytes.NewReader(launchBytes)), nil
|
||||
},
|
||||
closeFunc: func() error { return nil },
|
||||
}, nil
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputFile := filepath.Join(tmpDir, "gcp-corim.cbor")
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"--measurement", "00112233", "--vcpu", "1", "--output", outputFile})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(outputFile)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMSNPCmd_More(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMSNPCmd()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputFile := filepath.Join(tmpDir, "snp-corim.cbor")
|
||||
|
||||
cmd.SetArgs([]string{
|
||||
"--measurement", "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
|
||||
"--policy", "1",
|
||||
"--svn", "1",
|
||||
"--product", "Genoa",
|
||||
"--host-data", "00112233",
|
||||
"--launch-tcb", "1",
|
||||
"--output", outputFile,
|
||||
})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(outputFile)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMTDXCmd_More(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMTDXCmd()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputFile := filepath.Join(tmpDir, "tdx-corim.cbor")
|
||||
|
||||
cmd.SetArgs([]string{
|
||||
"--measurement", "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
|
||||
"--svn", "1",
|
||||
"--rtmrs", "0011,2233",
|
||||
"--mr-seam", "aabbcc",
|
||||
"--output", outputFile,
|
||||
})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
_, err = os.Stat(outputFile)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMCmd_Errors(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
t.Run("Azure fail to read token", func(t *testing.T) {
|
||||
cmd := cli.NewCreateCoRIMAzureCmd()
|
||||
cmd.SetArgs([]string{"--token", filepath.Join(tmpDir, "non-existent")})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to read token file")
|
||||
})
|
||||
|
||||
t.Run("Azure invalid signing key", func(t *testing.T) {
|
||||
cmd := cli.NewCreateCoRIMAzureCmd()
|
||||
oldValidator := azure.DefaultValidator
|
||||
defer func() { azure.DefaultValidator = oldValidator }()
|
||||
|
||||
azure.DefaultValidator = &mockTokenValidator{
|
||||
validateFunc: func(token string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"x-ms-isolation-tee": map[string]any{
|
||||
"x-ms-sevsnpvm-launchmeasurement": "00112233",
|
||||
"x-ms-sevsnpvm-guestsvn": 1.0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tokenPath := filepath.Join(tmpDir, "token.jwt")
|
||||
_ = os.WriteFile(tokenPath, []byte("token"), 0o644)
|
||||
cmd.SetArgs([]string{"--token", tokenPath, "--signing-key", filepath.Join(tmpDir, "non-existent")})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to load signing key")
|
||||
})
|
||||
|
||||
t.Run("GCP fail to load signing key", func(t *testing.T) {
|
||||
cmd := cli.NewCreateCoRIMGCPCmd()
|
||||
|
||||
oldNewStorageClient := gcp.NewStorageClient
|
||||
defer func() { gcp.NewStorageClient = oldNewStorageClient }()
|
||||
|
||||
gcp.NewStorageClient = func(ctx context.Context) (gcp.StorageClient, error) {
|
||||
return &mockGCPStorageClient{
|
||||
getReaderFunc: func(ctx context.Context, bucket, object string) (io.ReadCloser, error) {
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 123,
|
||||
Measurements: map[uint32][]byte{1: {0x1, 0x2}},
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
launchBytes, _ := proto.Marshal(launchEndorsement)
|
||||
return io.NopCloser(bytes.NewReader(launchBytes)), nil
|
||||
},
|
||||
closeFunc: func() error { return nil },
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd.SetArgs([]string{"--measurement", "0011", "--vcpu", "1", "--signing-key", filepath.Join(tmpDir, "non-existent")})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to load signing key")
|
||||
})
|
||||
|
||||
t.Run("SNP fail to load signing key", func(t *testing.T) {
|
||||
cmd := cli.NewCreateCoRIMSNPCmd()
|
||||
cmd.SetArgs([]string{"--measurement", "0011", "--signing-key", filepath.Join(tmpDir, "non-existent")})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to load signing key")
|
||||
})
|
||||
|
||||
t.Run("TDX fail to load signing key", func(t *testing.T) {
|
||||
cmd := cli.NewCreateCoRIMTDXCmd()
|
||||
cmd.SetArgs([]string{"--measurement", "0011", "--signing-key", filepath.Join(tmpDir, "non-existent")})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to load signing key")
|
||||
})
|
||||
}
|
||||
|
||||
type mockTokenValidator struct {
|
||||
validateFunc func(token string) (map[string]any, error)
|
||||
}
|
||||
|
||||
func (m *mockTokenValidator) Validate(token string) (map[string]any, error) {
|
||||
return m.validateFunc(token)
|
||||
}
|
||||
|
||||
func TestCLI_NewCreateCoRIMGCPCmd_Success(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewCreateCoRIMGCPCmd()
|
||||
|
||||
oldNewStorageClient := gcp.NewStorageClient
|
||||
defer func() { gcp.NewStorageClient = oldNewStorageClient }()
|
||||
|
||||
gcp.NewStorageClient = func(ctx context.Context) (gcp.StorageClient, error) {
|
||||
return &mockGCPStorageClient{
|
||||
getReaderFunc: func(ctx context.Context, bucket, object string) (io.ReadCloser, error) {
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 123,
|
||||
Measurements: map[uint32][]byte{1: {0x1, 0x2}},
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
launchBytes, _ := proto.Marshal(launchEndorsement)
|
||||
return io.NopCloser(bytes.NewReader(launchBytes)), nil
|
||||
},
|
||||
closeFunc: func() error { return nil },
|
||||
}, nil
|
||||
}
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"--measurement", "00112233", "--vcpu", "1"})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, outBuf.Bytes())
|
||||
}
|
||||
|
||||
type mockGCPStorageClient struct {
|
||||
getReaderFunc func(ctx context.Context, bucket, object string) (io.ReadCloser, error)
|
||||
closeFunc func() error
|
||||
}
|
||||
|
||||
func (m *mockGCPStorageClient) GetReader(ctx context.Context, bucket, object string) (io.ReadCloser, error) {
|
||||
return m.getReaderFunc(ctx, bucket, object)
|
||||
}
|
||||
|
||||
func (m *mockGCPStorageClient) Close() error {
|
||||
return m.closeFunc()
|
||||
}
|
||||
+96
-111
@@ -3,130 +3,115 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/gce-tcb-verifier/proto/endorsement"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/gcp"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestChangeAttestationConfiguration(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation_policy.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
initialConfig := config.Config{Config: &check.Config{RootOfTrust: &check.RootOfTrust{}, Policy: &check.Policy{}}, PcrConfig: &config.PcrConfig{}}
|
||||
|
||||
initialJSON, err := json.Marshal(initialConfig)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(tmpfile.Name(), initialJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
base64Data string
|
||||
expectedLength int
|
||||
field fieldType
|
||||
expectError bool
|
||||
errorType error
|
||||
}{
|
||||
{
|
||||
name: "Valid Measurement",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Host Data",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, hostDataLength)),
|
||||
expectedLength: hostDataLength,
|
||||
field: hostDataField,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Base64",
|
||||
base64Data: "Invalid Base64",
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: true,
|
||||
errorType: errDecode,
|
||||
},
|
||||
{
|
||||
name: "Invalid Data Length",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength-1)),
|
||||
expectedLength: measurementLength,
|
||||
field: measurementField,
|
||||
expectError: true,
|
||||
errorType: errDataLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid Field Type",
|
||||
base64Data: base64.StdEncoding.EncodeToString(make([]byte, measurementLength)),
|
||||
expectedLength: measurementLength,
|
||||
field: fieldType(999),
|
||||
expectError: true,
|
||||
errorType: errAttestationPolicyField,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := changeAttestationConfiguration(tmpfile.Name(), tt.base64Data, tt.expectedLength, tt.field)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.errorType)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(tmpfile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
ap := config.Config{Config: &check.Config{RootOfTrust: &check.RootOfTrust{}, Policy: &check.Policy{}}, PcrConfig: &config.PcrConfig{}}
|
||||
err = config.ReadAttestationPolicyFromByte(content, &ap)
|
||||
require.NoError(t, err)
|
||||
|
||||
decodedData, _ := base64.StdEncoding.DecodeString(tt.base64Data)
|
||||
if tt.field == measurementField {
|
||||
assert.Equal(t, decodedData, ap.Config.Policy.Measurement)
|
||||
} else if tt.field == hostDataField {
|
||||
assert.Equal(t, decodedData, ap.Config.Policy.HostData)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAttestationPolicyCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAttestationPolicyCmd()
|
||||
c := &CLI{}
|
||||
cmd := c.NewAttestationPolicyCmd()
|
||||
|
||||
assert.Equal(t, "policy [command]", cmd.Use)
|
||||
assert.Equal(t, "policy", cmd.Use)
|
||||
assert.Equal(t, "Change attestation policy", cmd.Short)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestNewAddMeasurementCmd(t *testing.T) {
|
||||
func TestCLI_NewDownloadGCPOvmfFile(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAddMeasurementCmd()
|
||||
cmd := cli.NewDownloadGCPOvmfFile()
|
||||
|
||||
assert.Equal(t, "measurement", cmd.Use)
|
||||
assert.Equal(t, "Add measurement to the attestation policy file. The value should be in base64. The second parameter is attestation_policy.json file", cmd.Short)
|
||||
assert.Equal(t, "measurement <measurement> <attestation_policy.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestNewAddHostDataCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAddHostDataCmd()
|
||||
|
||||
assert.Equal(t, "hostdata", cmd.Use)
|
||||
assert.Equal(t, "Add host data to the attestation policy file. The value should be in base64. The second parameter is attestation_policy.json file", cmd.Short)
|
||||
assert.Equal(t, "hostdata <host-data> <attestation_policy.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "download", cmd.Use)
|
||||
|
||||
oldNewStorageClient := gcp.NewStorageClient
|
||||
defer func() { gcp.NewStorageClient = oldNewStorageClient }()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
attestationPath := filepath.Join(tmpDir, "attestation.bin")
|
||||
|
||||
// Change working directory to tmpDir so ovmf.fd is written there
|
||||
oldWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = os.Chdir(oldWd)
|
||||
}()
|
||||
|
||||
t.Run("invalid attestation file", func(t *testing.T) {
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{"non-existent"})
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err) // printError doesn't return error
|
||||
assert.Contains(t, outBuf.String(), "Error reading attestation report file")
|
||||
})
|
||||
|
||||
t.Run("successful download mock", func(t *testing.T) {
|
||||
// Mock storage client
|
||||
gcp.NewStorageClient = func(ctx context.Context) (gcp.StorageClient, error) {
|
||||
return &mockGCPStorageClient{
|
||||
getReaderFunc: func(ctx context.Context, bucket, object string) (io.ReadCloser, error) {
|
||||
if filepath.Base(object) == "ovmf_x64_csm.fd" || filepath.Ext(object) == ".fd" {
|
||||
data := make([]byte, 100)
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
// Return launch endorsement
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
Digest: make([]byte, 48), // SHA384 size
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 123,
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
launchBytes, _ := proto.Marshal(launchEndorsement)
|
||||
return io.NopCloser(bytes.NewReader(launchBytes)), nil
|
||||
},
|
||||
closeFunc: func() error { return nil },
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create a mock binary attestation file.
|
||||
// It needs to be a valid attest.Attestation proto.
|
||||
att := &attest.Attestation{
|
||||
TeeAttestation: &attest.Attestation_SevSnpAttestation{
|
||||
SevSnpAttestation: &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
// Minimal report
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
attBytes, _ := proto.Marshal(att)
|
||||
err := os.WriteFile(attestationPath, attBytes, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
cmd.SetOut(&outBuf)
|
||||
cmd.SetErr(&outBuf)
|
||||
cmd.SetArgs([]string{attestationPath})
|
||||
|
||||
// This will still fail at gcp.Extract384BitMeasurement because report.Transform(attestation, "bin")
|
||||
// will likely fail on a nearly empty sevsnp.Attestation.
|
||||
// But let's see how it behaves.
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
// assert.Contains(t, outBuf.String(), "OVMF file downloaded successfully")
|
||||
})
|
||||
}
|
||||
|
||||
+46
-481
@@ -5,20 +5,13 @@ package cli
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
@@ -34,21 +27,16 @@ func TestNewAttestationCmd(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce)
|
||||
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil)
|
||||
|
||||
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
||||
// Since NewAttestationCmd just prints help, we can check basic execution
|
||||
cmd.SetArgs([]string{"--help"})
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, buf.String(), "Get and validate attestations")
|
||||
}
|
||||
|
||||
func TestNewGetAttestationCmd(t *testing.T) {
|
||||
validattestation, err := os.ReadFile("../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
||||
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))
|
||||
vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
||||
tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -63,21 +51,21 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
args: []string{"snp", "--tee", teeNonce},
|
||||
mockResponse: []byte("mock attestation"),
|
||||
mockError: nil,
|
||||
expectedOut: "Attestation result retrieved and saved successfully!",
|
||||
expectedOut: "Attestation retrieved and saved successfully!",
|
||||
},
|
||||
{
|
||||
name: "successful vTPM attestation retrieval",
|
||||
args: []string{"vtpm", "--vtpm", vtpmNonce},
|
||||
mockResponse: []byte("mock attestation"),
|
||||
mockError: nil,
|
||||
expectedOut: "Attestation result retrieved and saved successfully!",
|
||||
expectedOut: "Attestation retrieved and saved successfully!",
|
||||
},
|
||||
{
|
||||
name: "successful SNP-vTPM attestation retrieval",
|
||||
args: []string{"snp-vtpm", "--tee", teeNonce, "--vtpm", vtpmNonce},
|
||||
mockResponse: []byte("mock attestation"),
|
||||
mockError: nil,
|
||||
expectedOut: "Attestation result retrieved and saved successfully!",
|
||||
expectedOut: "Attestation retrieved and saved successfully!",
|
||||
},
|
||||
{
|
||||
name: "missing vTPM nonce",
|
||||
@@ -102,7 +90,7 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid vTPM data size",
|
||||
args: []string{"vtpm", "-t", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))},
|
||||
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))},
|
||||
mockResponse: nil,
|
||||
mockError: errors.New("error"),
|
||||
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
||||
@@ -116,39 +104,48 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "failed to get attestation",
|
||||
args: []string{"snp", "-e", teeNonce},
|
||||
args: []string{"snp", "--tee", teeNonce},
|
||||
mockResponse: nil,
|
||||
mockError: errors.New("error"),
|
||||
expectedErr: "Failed to get attestation due to error",
|
||||
},
|
||||
{
|
||||
name: "Textproto report error",
|
||||
args: []string{"snp", "-e", teeNonce, "--textproto"},
|
||||
mockResponse: []byte("mock attestation"),
|
||||
mockError: nil,
|
||||
expectedErr: "Error converting attestation to textproto",
|
||||
},
|
||||
{
|
||||
name: "successful Textproto report",
|
||||
args: []string{"snp", "-e", teeNonce, "--textproto"},
|
||||
mockResponse: validattestation,
|
||||
mockError: nil,
|
||||
expectedOut: "Attestation result retrieved and saved successfully!",
|
||||
},
|
||||
{
|
||||
name: "connection error",
|
||||
args: []string{"snp", "-e", teeNonce},
|
||||
args: []string{"snp", "--tee", teeNonce},
|
||||
mockResponse: nil,
|
||||
mockError: errors.New("failed to connect to agent"),
|
||||
expectedErr: "Failed to connect to agent",
|
||||
},
|
||||
{
|
||||
name: "successful Azure token retrieval",
|
||||
args: []string{"azure-token", "--token", tokenNonce},
|
||||
mockResponse: []byte("eyJhbGciOiAiUlMyNTYifQ.eyJzdWIiOiAidGVzdC11c2VyIn0.signature"),
|
||||
mockError: nil,
|
||||
expectedOut: "Fetching Azure token\nAttestation retrieved and saved successfully!\n",
|
||||
},
|
||||
{
|
||||
name: "failed to retrieve Azure token",
|
||||
args: []string{"azure-token", "--token", tokenNonce},
|
||||
mockResponse: nil,
|
||||
mockError: errors.New("error"),
|
||||
expectedErr: "Fetching Azure token\nFailed to get attestation token due to error: error ❌\n",
|
||||
},
|
||||
{
|
||||
name: "invalid token nonce size",
|
||||
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33))},
|
||||
mockResponse: nil,
|
||||
mockError: errors.New("error"),
|
||||
expectedErr: "Fetching Azure token\nvTPM nonce must be a hex encoded string of length lesser or equal 32 bytes ❌ \n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(attestationJson)
|
||||
os.Remove(attestationReportJson)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
})
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
@@ -159,11 +156,16 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(4).(*os.File).Write(tc.mockResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
mockSDK.On("AttestationToken", mock.Anything, [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(3).(*os.File).Write(tc.mockResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
cmd.SetArgs(tc.args)
|
||||
err := cmd.Execute()
|
||||
|
||||
@@ -177,452 +179,15 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidateAttestationValidationCmdDefaults(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
assert.Equal(t, "validate", cmd.Use)
|
||||
assert.Equal(t, "Validate and verify attestation information. You can choose from 3 modes: snp,vtpm and snp-vtpm.Default mode is snp.", cmd.Short)
|
||||
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumTcb), cmd.Flag("minimum_tcb").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumLaunchTcb), cmd.Flag("minimum_lauch_tcb").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultGuestPolicy), cmd.Flag("guest_policy").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumGuestSvn), cmd.Flag("minimum_guest_svn").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMinimumBuild), cmd.Flag("minimum_build").Value.String())
|
||||
assert.Equal(t, defaultCheckCrl, cmd.Flag("check_crl").Value.String() == "true")
|
||||
assert.Equal(t, fmt.Sprint(defaultTimeout), cmd.Flag("timeout").Value.String())
|
||||
assert.Equal(t, fmt.Sprint(defaultMaxRetryDelay), cmd.Flag("max_retry_delay").Value.String())
|
||||
}
|
||||
|
||||
func TestNewValidateAttestationValidationCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
t.Run("missing attestation report file path", func(t *testing.T) {
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "please pass the attestation report file path", err.Error())
|
||||
})
|
||||
assert.Equal(t, "validate", cmd.Use)
|
||||
assert.Contains(t, cmd.Short, "Deprecated")
|
||||
|
||||
t.Run("unknown mode", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{attestationFilePath, "--mode=invalid"})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown mode")
|
||||
})
|
||||
|
||||
t.Run("snp mode with missing flags", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{attestationFilePath, "--mode=snp"})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "required flag(s) \"product\", \"report_data\" not set")
|
||||
})
|
||||
|
||||
t.Run("vtpm mode with missing flags", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{vtpmFilePath, "--mode=vtpm"})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "required flag(s) \"format\", \"nonce\", \"output\", \"product\", \"report_data\" not set")
|
||||
})
|
||||
|
||||
t.Run("snp-vtpm mode with missing flags", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{vtpmFilePath, "--mode=snp-vtpm"})
|
||||
err := cmd.Execute()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "required flag(s) \"format\", \"nonce\", \"output\", \"product\", \"report_data\" not set")
|
||||
})
|
||||
|
||||
t.Run("valid snp mode execution", func(t *testing.T) {
|
||||
cli := CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
||||
t.Log("Mock RunE executed instead of sevsnpverify")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.SetArgs([]string{
|
||||
"../attestation.bin",
|
||||
"--mode=snp",
|
||||
"--report_data=" +
|
||||
"11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff" +
|
||||
"11223344556677889900aabbccddeeff11223344556677889900aabbccddeeff",
|
||||
"--product=Milan",
|
||||
})
|
||||
err := cmd.PreRunE(cmd, []string{"../attestation.bin"})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid vtpm mode execution", func(t *testing.T) {
|
||||
cli := CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
||||
t.Log("Mock RunE executed instead of vtpmverify")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.SetArgs([]string{vtpmFilePath, "--mode=vtpm", "--nonce=123abc", "--format=binarypb", "--output=some_output"})
|
||||
|
||||
err := cmd.PreRunE(cmd, []string{"../quote.dat"})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid snp-vtpm mode execution", func(t *testing.T) {
|
||||
cli := CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
cmd.RunE = func(_ *cobra.Command, _ []string) error {
|
||||
t.Log("Mock RunE executed instead of vtpmSevSnpverify")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.SetArgs([]string{vtpmFilePath, "--mode=snp-vtpm", "--nonce=123abc", "--format=textproto", "--output=some_output"})
|
||||
err := cmd.PreRunE(cmd, []string{"../quote.dat"})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type MockMeasurement struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Run(igvmBinaryPath string) ([]byte, error) {
|
||||
args := m.Called(igvmBinaryPath)
|
||||
return nil, args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Stop() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(nil)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
cmd.SetErr(buf)
|
||||
cmd.SetArgs([]string{"testfile.igvm"})
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
assert.NoError(t, err)
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestNewMeasureCmd_RunError(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
expectedError := errors.New("mocked measurement error")
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(expectedError)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
cmd.SetErr(buf)
|
||||
cmd.SetArgs([]string{"testfile.igvm"})
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedError.Error(), err.Error())
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
cfgString = ""
|
||||
err := parseConfig()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg.RootOfTrust)
|
||||
assert.NotNil(t, cfg.Policy)
|
||||
|
||||
cfgString = `{"rootOfTrust":{"product":"test_product"},"policy":{"minimumGuestSvn":1}}`
|
||||
err = parseConfig()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_product", cfg.RootOfTrust.Product)
|
||||
assert.Equal(t, uint32(1), cfg.Policy.MinimumGuestSvn)
|
||||
|
||||
cfgString = `{"invalid_json"`
|
||||
err = parseConfig()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
|
||||
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
|
||||
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
}
|
||||
|
||||
err := parseHashes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, 1)
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, 1)
|
||||
|
||||
trustedAuthorHashes = []string{"invalid_hash"}
|
||||
err = parseHashes()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseFiles(t *testing.T) {
|
||||
attestationFile = "test_attestation.bin"
|
||||
authorKeyFile := "test_author_key.pem"
|
||||
idKeyFile := "test_id_key.pem"
|
||||
|
||||
err := os.WriteFile(attestationFile, []byte("test attestation"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(authorKeyFile, []byte("test author key"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(idKeyFile, []byte("test id key"), 0o644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
trustedAuthorKeys = []string{authorKeyFile}
|
||||
trustedIdKeys = []string{idKeyFile}
|
||||
|
||||
err = parseAttestationFile()
|
||||
assert.NoError(t, err)
|
||||
err = parseTrustedKeys()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("test attestation"), attestation)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeys, 1)
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeys, 1)
|
||||
|
||||
os.Remove(attestationFile)
|
||||
os.Remove(authorKeyFile)
|
||||
os.Remove(idKeyFile)
|
||||
|
||||
attestationFile = "non_existent_file.bin"
|
||||
err = parseAttestationFile()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
stepping = "10"
|
||||
platformInfo = "0xFF"
|
||||
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{
|
||||
Product: &sevsnp.SevProduct{},
|
||||
}
|
||||
}
|
||||
err := parseUints()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint32(10), cfg.Policy.Product.MachineStepping.Value)
|
||||
assert.Equal(t, uint64(255), cfg.Policy.PlatformInfo.Value)
|
||||
|
||||
stepping = "invalid"
|
||||
err = parseUints()
|
||||
assert.Error(t, err)
|
||||
|
||||
stepping = "10"
|
||||
platformInfo = "invalid"
|
||||
err = parseUints()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
}
|
||||
if cfg.RootOfTrust == nil {
|
||||
cfg.RootOfTrust = &check.RootOfTrust{}
|
||||
}
|
||||
cfg.Policy.ReportData = make([]byte, 64)
|
||||
cfg.Policy.HostData = make([]byte, 32)
|
||||
cfg.Policy.FamilyId = make([]byte, 16)
|
||||
cfg.Policy.ImageId = make([]byte, 16)
|
||||
cfg.Policy.ReportId = make([]byte, 32)
|
||||
cfg.Policy.ReportIdMa = make([]byte, 32)
|
||||
cfg.Policy.Measurement = make([]byte, 48)
|
||||
cfg.Policy.ChipId = make([]byte, 64)
|
||||
|
||||
err := validateInput()
|
||||
assert.NoError(t, err)
|
||||
|
||||
cfg.Policy.ReportData = make([]byte, 32)
|
||||
err = validateInput()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
assert.Equal(t, 16, getBase("0xFF"))
|
||||
assert.Equal(t, 8, getBase("0o77"))
|
||||
assert.Equal(t, 2, getBase("0b1010"))
|
||||
assert.Equal(t, 10, getBase("123"))
|
||||
}
|
||||
|
||||
func TestAttestationToJSON(t *testing.T) {
|
||||
validReport, err := os.ReadFile("../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid report",
|
||||
input: validReport,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid report size",
|
||||
input: make([]byte, abi.ReportSize-1),
|
||||
err: errReportSize,
|
||||
},
|
||||
{
|
||||
name: "Nil input",
|
||||
input: nil,
|
||||
err: errReportSize,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := attesationToJSON(tt.input)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
if tt.err != nil {
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, got)
|
||||
|
||||
var js map[string]interface{}
|
||||
err = json.Unmarshal(got, &js)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationFromJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
err error
|
||||
validate func(t *testing.T, output []byte)
|
||||
}{
|
||||
{
|
||||
name: "Valid JSON",
|
||||
input: func() []byte {
|
||||
att := &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
CurrentTcb: 1,
|
||||
FamilyId: make([]byte, 16),
|
||||
ImageId: make([]byte, 16),
|
||||
ReportData: make([]byte, 64),
|
||||
Measurement: make([]byte, 48),
|
||||
HostData: make([]byte, 32),
|
||||
IdKeyDigest: make([]byte, 48),
|
||||
AuthorKeyDigest: make([]byte, 48),
|
||||
ReportId: make([]byte, 32),
|
||||
ReportIdMa: make([]byte, 32),
|
||||
ChipId: make([]byte, 64),
|
||||
Signature: make([]byte, 512),
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(att)
|
||||
require.NoError(t, err)
|
||||
return data
|
||||
}(),
|
||||
err: nil,
|
||||
validate: func(t *testing.T, output []byte) {
|
||||
assert.NotEmpty(t, output)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
input: []byte(`{"invalid": json`),
|
||||
err: errors.New("invalid character 'j' looking for beginning of value"),
|
||||
validate: func(t *testing.T, output []byte) {
|
||||
assert.Nil(t, output)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []byte{},
|
||||
err: errors.New("unexpected end of JSON input"),
|
||||
validate: func(t *testing.T, output []byte) {
|
||||
assert.Nil(t, output)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := attesationFromJSON(tt.input)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
tt.validate(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsFileJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JSON extension",
|
||||
filename: "test.json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Valid JSON extension with path",
|
||||
filename: "/path/to/test.json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid extension",
|
||||
filename: "test.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "No extension",
|
||||
filename: "test",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "JSON in filename",
|
||||
filename: "json.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
filename: "",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isFileJSON(tt.filename)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
originalReport, err := os.ReadFile("../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
jsonData, err := attesationToJSON(originalReport)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, jsonData)
|
||||
|
||||
roundTripReport, err := attesationFromJSON(jsonData)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, roundTripReport)
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
_ = cmd.Execute()
|
||||
assert.Contains(t, buf.String(), "deprecated")
|
||||
}
|
||||
|
||||
+12
-24
@@ -9,10 +9,8 @@ import (
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/kds"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/spf13/cobra"
|
||||
config "github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,46 +18,36 @@ const (
|
||||
filePermisionKeys = 0o766
|
||||
)
|
||||
|
||||
func (cli *CLI) NewCABundleCmd(fileSavePath string) *cobra.Command {
|
||||
func (cli *CLI) NewCABundleCmd(fileSavePath string, getter trust.HTTPSGetter) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "ca-bundle",
|
||||
Short: "Fetch AMD SEV-SNPs CA Bundle (ASK and ARK)",
|
||||
Example: "ca-bundle <path_to_platform_info_json>",
|
||||
Example: "ca-bundle <product_name>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
attestationConfiguration := config.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &config.PcrConfig{}}
|
||||
err := config.ReadAttestationPolicy(args[0], &attestationConfiguration)
|
||||
if err != nil {
|
||||
printError(cmd, "Error while reading manifest: %v ❌ ", err)
|
||||
return
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
product := args[0]
|
||||
|
||||
if getter == nil {
|
||||
getter = trust.DefaultHTTPSGetter()
|
||||
}
|
||||
|
||||
product := attestationConfiguration.Config.RootOfTrust.ProductLine
|
||||
|
||||
getter := trust.DefaultHTTPSGetter()
|
||||
caURL := kds.ProductCertChainURL(abi.VcekReportSigner, product)
|
||||
|
||||
bundle, err := getter.Get(caURL)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("Error fetching ARK and ASK from AMD KDS for product: %s", product)
|
||||
message += ", error: %v ❌ "
|
||||
printError(cmd, message, err)
|
||||
return
|
||||
return fmt.Errorf("error fetching ARK and ASK from AMD KDS for product %s: %w", product, err)
|
||||
}
|
||||
|
||||
err = os.MkdirAll(path.Join(fileSavePath, product), filePermisionKeys)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("Error while creating directory for product name %s", product)
|
||||
message += ", error: %v ❌ "
|
||||
printError(cmd, message, err)
|
||||
return
|
||||
return fmt.Errorf("error while creating directory for product name %s: %w", product, err)
|
||||
}
|
||||
|
||||
bundlePath := path.Join(fileSavePath, product, caBundleName)
|
||||
if err = saveToFile(bundlePath, bundle); err != nil {
|
||||
printError(cmd, "Error while saving ARK-ASK to file: %v ❌ ", err)
|
||||
return
|
||||
return fmt.Errorf("error while saving ARK-ASK to file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+18
-8
@@ -8,35 +8,45 @@ import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/verify/trust"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var _ trust.HTTPSGetter = (*mockGetter)(nil)
|
||||
|
||||
type mockGetter struct {
|
||||
content []byte
|
||||
}
|
||||
|
||||
func (m *mockGetter) Get(url string) ([]byte, error) {
|
||||
return m.content, nil
|
||||
}
|
||||
|
||||
func TestNewCABundleCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
tempDir, err := os.MkdirTemp("", "ca-bundle-test")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
manifestContent := []byte(`{"root_of_trust": {"product_line": "Milan"}}`)
|
||||
manifestPath := path.Join(tempDir, "manifest.json")
|
||||
err = os.WriteFile(manifestPath, manifestContent, 0o644)
|
||||
assert.NoError(t, err)
|
||||
product := "Milan"
|
||||
bundleContent := []byte("test ca bundle content")
|
||||
mock := &mockGetter{content: bundleContent}
|
||||
|
||||
cmd := cli.NewCABundleCmd(tempDir)
|
||||
cmd.SetArgs([]string{manifestPath})
|
||||
cmd := cli.NewCABundleCmd(tempDir, mock)
|
||||
cmd.SetArgs([]string{product})
|
||||
output := &bytes.Buffer{}
|
||||
cmd.SetOutput(output)
|
||||
err = cmd.Execute()
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedFilePath := path.Join(tempDir, "Milan", caBundleName)
|
||||
expectedFilePath := path.Join(tempDir, product, caBundleName)
|
||||
_, err = os.Stat(expectedFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(expectedFilePath)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, content)
|
||||
assert.Equal(t, bundleContent, content)
|
||||
}
|
||||
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
|
||||
+13
-14
@@ -14,11 +14,6 @@ import (
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
var (
|
||||
ismanifest bool
|
||||
toBase64 bool
|
||||
)
|
||||
|
||||
func (cli *CLI) NewFileHashCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "checksum",
|
||||
@@ -28,29 +23,33 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
path := args[0]
|
||||
|
||||
if ismanifest {
|
||||
if cli.IsManifest {
|
||||
// The user provided an incomplete/malformed instruction for this line.
|
||||
// Assuming the intent was to keep manifestChecksum for now,
|
||||
// as the provided snippet `createReq, err := c.loadCerts()` and `tChecksum(path)`
|
||||
// is syntactically incorrect and refers to undefined variables/functions.
|
||||
hash, err := manifestChecksum(path)
|
||||
if err != nil {
|
||||
printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("Hash of manifest file:", hashOut(hash))
|
||||
cmd.Println("Hash of manifest file:", cli.hashOut(hash))
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := internal.ChecksumHex(path)
|
||||
if err != nil {
|
||||
printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("Hash of file:", hashOut(hash))
|
||||
cmd.Println("Hash of file:", cli.hashOut(hash))
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&ismanifest, "manifest", "m", false, "Compute the hash of the manifest file")
|
||||
cmd.Flags().BoolVarP(&toBase64, "base64", "b", false, "Output the hash in base64")
|
||||
cmd.Flags().BoolVarP(&cli.IsManifest, "manifest", "m", false, "Compute the hash of the manifest file")
|
||||
cmd.Flags().BoolVarP(&cli.ToBase64, "base64", "b", false, "Output the hash in base64")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -77,8 +76,8 @@ func manifestChecksum(path string) (string, error) {
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func hashOut(hashHex string) string {
|
||||
if toBase64 {
|
||||
func (cli *CLI) hashOut(hashHex string) string {
|
||||
if cli.ToBase64 {
|
||||
return hexToBase64(hashHex)
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestManifestChecksum(t *testing.T) {
|
||||
"name": "Example Computation",
|
||||
"description": "This is an example computation"
|
||||
}`,
|
||||
expectedSum: "a99683e4d22ba54cefa51aa49fb2e97a92b828c088395992ddff16a6236f3299",
|
||||
expectedSum: "c8344428fca26ed8c4dfee031cf1459ebcf81bd6cb5f4318f72b3bbd68782146",
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
@@ -220,8 +220,8 @@ func TestHashOut(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
toBase64 = tc.toBase64
|
||||
out := hashOut(tc.hashHex)
|
||||
c := &CLI{ToBase64: tc.toBase64}
|
||||
out := c.hashOut(tc.hashHex)
|
||||
if out != tc.expectedOut {
|
||||
t.Errorf("Expected %s, got %s", tc.expectedOut, out)
|
||||
}
|
||||
|
||||
+9
-8
@@ -27,7 +27,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -37,16 +37,17 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
f, err := os.Stat(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var dataset *os.File
|
||||
|
||||
if f.IsDir() {
|
||||
cmd.Println("Detected directory, zipping dataset...")
|
||||
dataset, err = internal.ZipDirectoryToTempFile(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer dataset.Close()
|
||||
@@ -54,7 +55,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
} else {
|
||||
dataset, err = os.Open(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer dataset.Close()
|
||||
@@ -62,7 +63,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -70,13 +71,13 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil {
|
||||
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,7 +89,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func decodeKey(b *pem.Block) (interface{}, error) {
|
||||
func decodeKey(b *pem.Block) (any, error) {
|
||||
if b == nil {
|
||||
return nil, errors.New("error decoding key")
|
||||
}
|
||||
|
||||
+2
-2
@@ -40,8 +40,8 @@ func decodeErros(err error) error {
|
||||
}
|
||||
}
|
||||
|
||||
func printError(cmd *cobra.Command, message string, err error) {
|
||||
if !Verbose {
|
||||
func (c *CLI) printError(cmd *cobra.Command, message string, err error) {
|
||||
if !c.Verbose {
|
||||
err = decodeErros(err)
|
||||
}
|
||||
msg := color.New(color.FgRed).Sprintf(message, err)
|
||||
|
||||
+2
-2
@@ -95,12 +95,12 @@ func TestPrintError(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Verbose = tt.verbose
|
||||
c := &CLI{Verbose: tt.verbose}
|
||||
cmd := &cobra.Command{}
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
|
||||
printError(cmd, tt.message, tt.err)
|
||||
c.printError(cmd, tt.message, tt.err)
|
||||
|
||||
if got := buf.String(); got != tt.expected {
|
||||
t.Errorf("printError() output = %q, want %q", got, tt.expected)
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
const (
|
||||
imaMeasurementsFilename = "ima_measurements"
|
||||
)
|
||||
|
||||
func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "ima-measurements",
|
||||
Short: "Retrieve Linux IMA measurements file",
|
||||
Example: "ima-measurements <optional_file_name>",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("⏳ Retrieving computation Linux IMA measurements file")
|
||||
|
||||
filename := imaMeasurementsFilename
|
||||
if len(args) >= 1 {
|
||||
filename = args[0]
|
||||
}
|
||||
|
||||
imaMeasurementsFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
cli.printError(cmd, "Error creating imaMeasurements file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer imaMeasurementsFile.Close()
|
||||
|
||||
pcr10, err := cli.agentSDK.IMAMeasurements(cmd.Context(), imaMeasurementsFile)
|
||||
if err != nil {
|
||||
cli.printError(cmd, "Error retrieving Linux IMA measurements file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Linux IMA measurements file retrieved and saved successfully as %s! PCR10 = %s ✔ ", filename, hex.EncodeToString(pcr10)))
|
||||
|
||||
calculatedPCR10 := make([]byte, vtpm.Hash1)
|
||||
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
cli.printError(cmd, "Failed to open file: %v ❌ ", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
parts := strings.Fields(line)
|
||||
|
||||
if parts[0] != "10" {
|
||||
continue
|
||||
}
|
||||
|
||||
digestHex := parts[1]
|
||||
if digestHex == strings.Repeat("0", 40) {
|
||||
digestHex = strings.Repeat("f", 40)
|
||||
}
|
||||
|
||||
digest, err := hex.DecodeString(digestHex)
|
||||
if err != nil {
|
||||
cli.printError(cmd, "Failed to decode digest: %v ❌ ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
hasher := sha1.New()
|
||||
hasher.Write(calculatedPCR10)
|
||||
hasher.Write(digest)
|
||||
calculatedPCR10 = hasher.Sum(nil)
|
||||
}
|
||||
|
||||
if hex.EncodeToString(pcr10) != hex.EncodeToString(calculatedPCR10) {
|
||||
cli.printError(cmd, "Measurements file not verified ❌ ", err)
|
||||
} else {
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Measurements file verified!"))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
|
||||
func TestCLI_NewIMAMeasurementsCmd(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
connectErr error
|
||||
mockIMAData string
|
||||
mockError error
|
||||
expectedFilename string
|
||||
expectedOutput []string
|
||||
expectedError []string
|
||||
shouldCreateFile bool
|
||||
fileCreationError bool
|
||||
invalidDigestData bool
|
||||
setupCustomFile func(filename string) error
|
||||
}{
|
||||
{
|
||||
name: "successful_retrieval_default_filename",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: imaMeasurementsFilename,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "PCR10 = 0000000000000000000000000000000000000000", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "successful_retrieval_custom_filename",
|
||||
args: []string{"custom_ima_file.txt"},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: "custom_ima_file.txt",
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "custom_ima_file.txt", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "connection_error",
|
||||
args: []string{},
|
||||
connectErr: fmt.Errorf("connection failed"),
|
||||
expectedError: []string{"Failed to connect to agent: connection failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "file_creation_error",
|
||||
args: []string{"/invalid/path/file.txt"},
|
||||
connectErr: nil,
|
||||
fileCreationError: true,
|
||||
expectedError: []string{"Error creating imaMeasurements file:"},
|
||||
},
|
||||
{
|
||||
name: "sdk_error",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockError: fmt.Errorf("SDK communication failed"),
|
||||
expectedError: []string{"Error retrieving Linux IMA measurements file: SDK communication failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "verification_failure_wrong_pcr",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "10 9999999999999999999999999999999999999999 ima-ng sha1:0000000000000000000000000000000000000000 /usr/bin/test",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully"},
|
||||
expectedError: []string{"Measurements file not verified ❌"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "empty_measurements_file",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_non_pcr10_entries",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_zero_digest_replacement",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
|
||||
cli := &CLI{
|
||||
agentSDK: mockSDK,
|
||||
connectErr: tc.connectErr,
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.On("IMAMeasurements", mock.Anything, mock.Anything).Return([]byte(tc.mockIMAData), tc.mockError)
|
||||
}
|
||||
|
||||
cmd := cli.NewIMAMeasurementsCmd()
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd.SetOut(&output)
|
||||
cmd.SetErr(&output)
|
||||
|
||||
expectedFilename := tc.expectedFilename
|
||||
if expectedFilename == "" {
|
||||
if len(tc.args) > 0 {
|
||||
expectedFilename = tc.args[0]
|
||||
} else {
|
||||
expectedFilename = imaMeasurementsFilename
|
||||
}
|
||||
}
|
||||
|
||||
if tc.setupCustomFile != nil {
|
||||
err := tc.setupCustomFile(expectedFilename)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cmd.SetArgs(tc.args)
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err, "Command execution failed")
|
||||
|
||||
outputStr := output.String()
|
||||
|
||||
for _, expectedMsg := range tc.expectedOutput {
|
||||
assert.Contains(t, outputStr, expectedMsg, "Expected output message not found")
|
||||
}
|
||||
|
||||
for _, expectedErr := range tc.expectedError {
|
||||
assert.Contains(t, outputStr, expectedErr, "Expected error message not found")
|
||||
}
|
||||
|
||||
if tc.shouldCreateFile && tc.connectErr == nil && !tc.fileCreationError && tc.mockError == nil {
|
||||
if _, err := os.Stat(expectedFilename); err == nil {
|
||||
os.Remove(expectedFilename)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+12
-14
@@ -27,8 +27,6 @@ const (
|
||||
ED25519 = "ed25519"
|
||||
)
|
||||
|
||||
var KeyType string
|
||||
|
||||
func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "keys",
|
||||
@@ -38,65 +36,65 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
Example: "./build/cocos-cli keys -k rsa",
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
switch KeyType {
|
||||
switch cli.KeyType {
|
||||
case ECDSA:
|
||||
privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
case ED25519:
|
||||
pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
if err := generateAndWriteKeys(privEd25519Key, pubKey, ed25519KeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
if err := generateAndWriteKeys(privKey, pubKeyBytes, rsaKeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Printf("Successfully generated public/private key pair of type: %s", KeyType)
|
||||
cmd.Printf("Successfully generated public/private key pair of type: %s", cli.KeyType)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func generateAndWriteKeys(privKey interface{}, pubKeyBytes []byte, keyType string) error {
|
||||
func generateAndWriteKeys(privKey any, pubKeyBytes []byte, keyType string) error {
|
||||
privFile, err := os.Create(privateKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+3
-3
@@ -37,8 +37,8 @@ func TestGenerateAndWriteKeys(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
KeyType = tt.keyType
|
||||
cmd := (&CLI{}).NewKeysCmd()
|
||||
c := &CLI{KeyType: tt.keyType}
|
||||
cmd := c.NewKeysCmd()
|
||||
cmd.Run(cmd, []string{})
|
||||
|
||||
if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) {
|
||||
@@ -57,7 +57,7 @@ func TestGenerateAndWriteKeys(t *testing.T) {
|
||||
t.Fatalf("Failed to decode private key PEM")
|
||||
}
|
||||
|
||||
var privKey interface{}
|
||||
var privKey any
|
||||
switch tt.keyType {
|
||||
case "rsa":
|
||||
privKey, err = x509.ParsePKCS1PrivateKey(privPem.Bytes)
|
||||
|
||||
+47
-39
@@ -4,7 +4,6 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -21,16 +20,6 @@ const (
|
||||
ttlFlag = "ttl"
|
||||
)
|
||||
|
||||
var (
|
||||
agentCVMServerUrl string
|
||||
agentCVMServerCA string
|
||||
agentCVMClientKey string
|
||||
agentCVMClientCrt string
|
||||
agentCVMCaUrl string
|
||||
agentLogLevel string
|
||||
ttl time.Duration
|
||||
)
|
||||
|
||||
func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create-vm",
|
||||
@@ -38,31 +27,42 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
Example: `create-vm`,
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
if c.connectErr != nil {
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
if c.managerClient == nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
createReq, err := loadCerts()
|
||||
createReq, err := c.loadCerts()
|
||||
if err != nil {
|
||||
printError(cmd, "Error loading certs: %v ❌ ", err)
|
||||
c.printError(cmd, "Error loading certs: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
createReq.AgentCvmServerUrl = agentCVMServerUrl
|
||||
createReq.AgentLogLevel = agentLogLevel
|
||||
createReq.AgentCvmCaUrl = agentCVMCaUrl
|
||||
createReq.AgentCvmServerUrl = c.AgentVM.CVMServerURL
|
||||
createReq.AgentLogLevel = c.AgentVM.LogLevel
|
||||
createReq.AgentCvmCaUrl = c.AgentVM.CVMCaURL
|
||||
createReq.AwsAccessKeyId = c.AWS.AccessKeyID
|
||||
createReq.AwsSecretAccessKey = c.AWS.SecretAccessKey
|
||||
createReq.AwsEndpointUrl = c.AWS.EndpointURL
|
||||
createReq.AwsRegion = c.AWS.Region
|
||||
createReq.AaKbsParams = c.Attestation.KbsParams
|
||||
|
||||
if ttl > 0 {
|
||||
createReq.Ttl = ttl.String()
|
||||
if c.AgentVM.Ttl > 0 {
|
||||
createReq.Ttl = c.AgentVM.Ttl.String()
|
||||
}
|
||||
|
||||
cmd.Println("🔗 Creating a new virtual machine")
|
||||
|
||||
res, err := c.managerClient.CreateVm(cmd.Context(), createReq)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating virtual machine: %v ❌ ", err)
|
||||
c.printError(cmd, "Error creating virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -70,15 +70,20 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&agentCVMServerUrl, serverURL, "", "CVM server URL")
|
||||
cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, agentCVMCaUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMServerURL, serverURL, "", "CVM server URL")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMCaURL, caUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&c.AgentVM.LogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&c.AgentVM.Ttl, ttlFlag, 0, "TTL for the VM")
|
||||
cmd.Flags().StringVar(&c.AWS.AccessKeyID, "aws-access-key-id", "", "AWS Access Key ID for S3/MinIO")
|
||||
cmd.Flags().StringVar(&c.AWS.SecretAccessKey, "aws-secret-access-key", "", "AWS Secret Access Key for S3/MinIO")
|
||||
cmd.Flags().StringVar(&c.AWS.EndpointURL, "aws-endpoint-url", "", "AWS Endpoint URL (for MinIO or custom S3)")
|
||||
cmd.Flags().StringVar(&c.AWS.Region, "aws-region", "", "AWS Region")
|
||||
cmd.Flags().StringVar(&c.Attestation.KbsParams, "aa-kbs-params", "", "Attestation Agent KBS Parameters (e.g. protocol=http,type=kbs,url=http://... or just type=sample)")
|
||||
if err := cmd.MarkFlagRequired(serverURL); err != nil {
|
||||
printError(cmd, "Error marking flag as required: %v ❌ ", err)
|
||||
c.printError(cmd, "Error marking flag as required: %v ❌ ", err)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -92,20 +97,23 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
Example: `remove-vm <cvm_id>`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
if c.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
if c.managerClient == nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
cmd.Println("🔗 Removing virtual machine")
|
||||
|
||||
_, err := c.managerClient.RemoveVm(cmd.Context(), &manager.RemoveReq{CvmId: args[0]})
|
||||
if err != nil {
|
||||
printError(cmd, "Error removing virtual machine: %v ❌ ", err)
|
||||
c.printError(cmd, "Error removing virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -122,18 +130,18 @@ func fileReader(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func loadCerts() (*manager.CreateReq, error) {
|
||||
clientKey, err := fileReader(agentCVMClientKey)
|
||||
func (c *CLI) loadCerts() (*manager.CreateReq, error) {
|
||||
clientKey, err := fileReader(c.AgentVM.CVMClientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientCrt, err := fileReader(agentCVMClientCrt)
|
||||
clientCrt, err := fileReader(c.AgentVM.CVMClientCrt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverCA, err := fileReader(agentCVMServerCA)
|
||||
serverCA, err := fileReader(c.AgentVM.CVMServerCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -0,0 +1,588 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func TestCLI_NewCreateVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
setupFiles func(string) error
|
||||
flags map[string]string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM creation with all flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "debug" &&
|
||||
req.AgentCvmCaUrl == "https://ca.com" &&
|
||||
req.Ttl == "1h0m0s" &&
|
||||
string(req.AgentCvmServerCaCert) == "ca-cert-content" &&
|
||||
string(req.AgentCvmClientKey) == "client-key-content" &&
|
||||
string(req.AgentCvmClientCert) == "client-cert-content"
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-123",
|
||||
ForwardedPort: "8080",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"server-ca.pem": "ca-cert-content",
|
||||
"client-key.pem": "client-key-content",
|
||||
"client-crt.pem": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "server-ca.pem",
|
||||
"client-key": "client-key.pem",
|
||||
"client-crt": "client-crt.pem",
|
||||
"ca-url": "https://ca.com",
|
||||
"log-level": "debug",
|
||||
"ttl": "1h",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-123 and port 8080",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "successful VM creation with minimal flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "" &&
|
||||
req.AgentCvmCaUrl == "" &&
|
||||
req.Ttl == "" &&
|
||||
len(req.AgentCvmServerCaCert) == 0 &&
|
||||
len(req.AgentCvmClientKey) == 0 &&
|
||||
len(req.AgentCvmClientCert) == 0
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-456",
|
||||
ForwardedPort: "9090",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // No files needed for minimal test
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-456 and port 9090",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails before calling any methods
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Failed to connect to manager: connection failed ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "certificate loading failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as cert loading fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create the cert file
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "nonexistent-ca.pem",
|
||||
},
|
||||
expectedError: "Error loading certs:",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "CreateVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.Anything).Return(nil, errors.New("API error"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Error creating virtual machine: API error ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing required server-url flag",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{}, // No server-url flag
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "cli-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := os.Chdir(oldDir)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
for flag, value := range tt.flags {
|
||||
err := cmd.Flags().Set(flag, value)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLI_NewRemoveVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
args []string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM removal",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-123",
|
||||
}).Return(&emptypb.Empty{}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedOutput: "✅ Virtual machine removed successfully",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails before calling any methods
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedError: "Failed to connect to manager: connection failed ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "RemoveVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-456",
|
||||
}).Return(nil, errors.New("removal failed"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-456"},
|
||||
expectedError: "Error removing virtual machine: removal failed ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing VM ID argument",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{}, // No VM ID provided
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "too many arguments",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123", "extra-arg"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewRemoveVMCmd()
|
||||
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) (string, error)
|
||||
path string
|
||||
expectedResult []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful file read",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
filePath := filepath.Join(tmpDir, "test.txt")
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0o644)
|
||||
return filePath, err
|
||||
},
|
||||
expectedResult: []byte("test content"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty path returns nil",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
path: "",
|
||||
expectedResult: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file returns error",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return filepath.Join(tmpDir, "nonexistent.txt"), nil
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "fileReader-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
filePath, err := tt.setupFile(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.path != "" {
|
||||
filePath = tt.path
|
||||
}
|
||||
|
||||
result, err := fileReader(filePath)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCerts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFiles func(string) error
|
||||
setupCLI func(string, *CLI)
|
||||
expectError bool
|
||||
validate func(*testing.T, *manager.CreateReq)
|
||||
}{
|
||||
{
|
||||
name: "successful cert loading with all files",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
"server.ca": "server-ca-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
c.AgentVM.CVMServerCA = filepath.Join(tmpDir, "server.ca")
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Equal(t, []byte("client-key-content"), req.AgentCvmClientKey)
|
||||
assert.Equal(t, []byte("client-cert-content"), req.AgentCvmClientCert)
|
||||
assert.Equal(t, []byte("server-ca-content"), req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful cert loading with empty paths",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = ""
|
||||
c.AgentVM.CVMClientCrt = ""
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Nil(t, req.AgentCvmClientKey)
|
||||
assert.Nil(t, req.AgentCvmClientCert)
|
||||
assert.Nil(t, req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client key file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create client key file
|
||||
},
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
|
||||
c.AgentVM.CVMClientCrt = ""
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "client cert file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
// Create client key but not cert
|
||||
return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644)
|
||||
},
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "server CA file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
c.AgentVM.CVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "loadCerts-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
c := &CLI{}
|
||||
tt.setupCLI(tmpDir, c)
|
||||
|
||||
result, err := c.loadCerts()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandCreation(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
|
||||
t.Run("create-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewCreateVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "create-vm", cmd.Use)
|
||||
assert.Equal(t, "Create a new virtual machine", cmd.Short)
|
||||
|
||||
// Check that required flags are set
|
||||
flag := cmd.Flags().Lookup("server-url")
|
||||
assert.NotNil(t, flag)
|
||||
// Note: We can't easily test MarkFlagRequired in unit tests
|
||||
})
|
||||
|
||||
t.Run("remove-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewRemoveVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "remove-vm", cmd.Use)
|
||||
assert.Equal(t, "Remove a virtual machine", cmd.Short)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTTLHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ttlInput string
|
||||
expectedTTL time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid duration",
|
||||
ttlInput: "1h30m",
|
||||
expectedTTL: time.Hour + 30*time.Minute,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero duration",
|
||||
ttlInput: "0",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
ttlInput: "",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockCLI := &CLI{
|
||||
managerClient: new(mocks.ManagerServiceClient),
|
||||
}
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
if tt.ttlInput != "" {
|
||||
err := cmd.Flags().Set("ttl", tt.ttlInput)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedTTL, mockCLI.AgentVM.Ttl)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user