mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-584 - Support multiple kbs (#587)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat: Implement per-resource KBS configuration, allowing algorithms and datasets to specify individual KBS URLs. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Encapsulate CLI error handling and CVM certificate paths within the CLI struct, and add algorithm type to agent's algorithm structure. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * style: Remove blank lines and fix indentation in CLI commands. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Update downloadAndDecryptGenericResource to accept KBS URL as a parameter and adjust related tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: group CLI configuration into structured types and simplify skopeo decryption key handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c59a413765
commit
d5badba547
@@ -281,17 +281,22 @@ go build -o build/cvms-test ./test/cvms/main.go
|
||||
HOST=$HOST_IP PORT=7001 ./build/cvms-test \
|
||||
-public-key-path ./public.pem \
|
||||
-attested-tls-bool false \
|
||||
-kbs-url http://$HOST_IP:8080 \
|
||||
-algo-type python \
|
||||
-algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \
|
||||
-algo-kbs-path default/key/algo-key \
|
||||
-algo-kbs-url http://$HOST_IP:8080 \
|
||||
-algo-hash $ALGO_HASH \
|
||||
-algo-args datasets/dataset_0.csv \
|
||||
-dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \
|
||||
-dataset-kbs-paths default/key/dataset-key \
|
||||
-dataset-kbs-urls http://$HOST_IP:8080 \
|
||||
-dataset-hash $DATASET_HASH
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> You must specify the KBS URL for each encrypted resource using `-algo-kbs-url` and `-dataset-kbs-urls`. A global KBS is no longer supported.
|
||||
|
||||
|
||||
### 3. Create VM via CLI (Host)
|
||||
|
||||
```bash
|
||||
@@ -357,17 +362,31 @@ The CVMS server sends this manifest to the agent:
|
||||
"type": "oci-image",
|
||||
"uri": "docker://localhost:5000/encrypted-lin-reg:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/algo-key"
|
||||
"kbs_resource_path": "default/key/algo-key",
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.15:8080",
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"type": "oci-image",
|
||||
"uri": "docker://localhost:5000/encrypted-iris:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/dataset-key"
|
||||
"filename": "iris.csv",
|
||||
"source": {
|
||||
"type": "oci-image",
|
||||
"url": "docker://localhost:5000/encrypted-iris:v1.0",
|
||||
"encrypted": true,
|
||||
"kbs_resource_path": "default/key/dataset-key"
|
||||
},
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.20:8080",
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
],
|
||||
"kbs_url": "http://192.168.100.15:8080"
|
||||
"kbs": {
|
||||
"url": "http://192.168.100.15:8080",
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
manifest := agent.Computation{
|
||||
ResultConsumers: []agent.ResultConsumer{{UserKey: resultConsumerPubKey}},
|
||||
Datasets: []agent.Dataset{{UserKey: dataProviderPubKey}},
|
||||
Algorithm: agent.Algorithm{UserKey: algorithmProviderPubKey},
|
||||
Algorithm: &agent.Algorithm{UserKey: algorithmProviderPubKey},
|
||||
}
|
||||
|
||||
auth, err := New(manifest)
|
||||
|
||||
@@ -52,9 +52,8 @@ type Computation struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Datasets Datasets `json:"datasets,omitempty"`
|
||||
Algorithm Algorithm `json:"algorithm,omitempty"`
|
||||
Algorithm *Algorithm `json:"algorithm,omitempty"`
|
||||
ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"`
|
||||
KBS KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type ResultConsumer struct {
|
||||
@@ -76,6 +75,7 @@ type Dataset struct {
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
Decompress bool `json:"decompress,omitempty"`
|
||||
KBS *KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type Datasets []Dataset
|
||||
@@ -88,6 +88,7 @@ type Algorithm struct {
|
||||
Source *ResourceSource `json:"source,omitempty"` // Optional remote source
|
||||
AlgoType string `json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `json:"algo_args,omitempty"`
|
||||
KBS *KBSConfig `json:"kbs,omitempty"`
|
||||
}
|
||||
|
||||
type ManifestIndexKey struct{}
|
||||
|
||||
@@ -234,9 +234,10 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
}
|
||||
|
||||
if runReq.Algorithm != nil {
|
||||
ac.Algorithm = agent.Algorithm{
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
ac.Algorithm = &agent.Algorithm{
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
AlgoType: runReq.Algorithm.AlgoType,
|
||||
}
|
||||
// Copy remote source if configured
|
||||
if runReq.Algorithm.Source != nil {
|
||||
@@ -246,8 +247,13 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
Encrypted: runReq.Algorithm.Source.Encrypted,
|
||||
}
|
||||
}
|
||||
ac.Algorithm.AlgoType = runReq.Algorithm.AlgoType
|
||||
ac.Algorithm.AlgoArgs = runReq.Algorithm.AlgoArgs
|
||||
if runReq.Algorithm.Kbs != nil {
|
||||
ac.Algorithm.KBS = &agent.KBSConfig{
|
||||
URL: runReq.Algorithm.Kbs.Url,
|
||||
Enabled: runReq.Algorithm.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ds := range runReq.Datasets {
|
||||
@@ -265,6 +271,12 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
}
|
||||
}
|
||||
dataset.Decompress = ds.Decompress
|
||||
if ds.Kbs != nil {
|
||||
dataset.KBS = &agent.KBSConfig{
|
||||
URL: ds.Kbs.Url,
|
||||
Enabled: ds.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
ac.Datasets = append(ac.Datasets, dataset)
|
||||
}
|
||||
|
||||
@@ -274,14 +286,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
})
|
||||
}
|
||||
|
||||
// Copy KBS configuration
|
||||
if runReq.Kbs != nil {
|
||||
ac.KBS = agent.KBSConfig{
|
||||
URL: runReq.Kbs.Url,
|
||||
Enabled: runReq.Kbs.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the agent is in the correct state to initialize a new computation.
|
||||
// If the agent is already processing this computation (e.g., after a reconnection),
|
||||
// skip initialization to avoid state errors.
|
||||
|
||||
@@ -554,11 +554,12 @@ func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) {
|
||||
KbsResourcePath: "default/key/algo-key",
|
||||
Encrypted: true,
|
||||
},
|
||||
Kbs: &cvms.KBSConfig{
|
||||
Url: "https://kbs.example.com:8080",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Kbs: &cvms.KBSConfig{
|
||||
Url: "https://kbs.example.com:8080",
|
||||
Enabled: true,
|
||||
},
|
||||
|
||||
ResultConsumers: []*cvms.ResultConsumer{
|
||||
{
|
||||
UserKey: []byte("test-consumer"),
|
||||
@@ -577,8 +578,8 @@ func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) {
|
||||
|
||||
mockSvc.On("State").Return("ReceivingManifest")
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.MatchedBy(func(c agent.Computation) bool {
|
||||
// Verify KBS config is passed
|
||||
if !c.KBS.Enabled || c.KBS.URL != "https://kbs.example.com:8080" {
|
||||
// Verify Algorithm KBS config is passed
|
||||
if c.Algorithm.KBS == nil || !c.Algorithm.KBS.Enabled || c.Algorithm.KBS.URL != "https://kbs.example.com:8080" {
|
||||
return false
|
||||
}
|
||||
// Verify algorithm source is passed
|
||||
|
||||
+33
-23
@@ -826,7 +826,6 @@ type ComputationRunReq struct {
|
||||
Algorithm *Algorithm `protobuf:"bytes,5,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
|
||||
ResultConsumers []*ResultConsumer `protobuf:"bytes,6,rep,name=result_consumers,json=resultConsumers,proto3" json:"result_consumers,omitempty"`
|
||||
AgentConfig *AgentConfig `protobuf:"bytes,7,opt,name=agent_config,json=agentConfig,proto3" json:"agent_config,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,8,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration for remote resources
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -910,13 +909,6 @@ func (x *ComputationRunReq) GetAgentConfig() *AgentConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *ComputationRunReq) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ResultConsumer struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
UserKey []byte `protobuf:"bytes,1,opt,name=userKey,proto3" json:"userKey,omitempty"`
|
||||
@@ -968,6 +960,7 @@ type Dataset struct {
|
||||
Filename string `protobuf:"bytes,3,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
Source *Source `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted dataset
|
||||
Decompress bool `protobuf:"varint,5,opt,name=decompress,proto3" json:"decompress,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1037,6 +1030,13 @@ func (x *Dataset) GetDecompress() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *Dataset) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Algorithm struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length.
|
||||
@@ -1044,6 +1044,7 @@ type Algorithm struct {
|
||||
Source *Source `protobuf:"bytes,3,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted algorithm
|
||||
AlgoType string `protobuf:"bytes,4,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"`
|
||||
AlgoArgs []string `protobuf:"bytes,5,rep,name=algo_args,json=algoArgs,proto3" json:"algo_args,omitempty"`
|
||||
Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1113,6 +1114,13 @@ func (x *Algorithm) GetAlgoArgs() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Algorithm) GetKbs() *KBSConfig {
|
||||
if x != nil {
|
||||
return x.Kbs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` // Type of source: "oci-image" (only OCI images supported for CoCo)
|
||||
@@ -1485,7 +1493,7 @@ const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\fRunReqChunks\x12\x12\n" +
|
||||
"\x04data\x18\x01 \x01(\fR\x04data\x12\x0e\n" +
|
||||
"\x02id\x18\x02 \x01(\tR\x02id\x12\x17\n" +
|
||||
"\ais_last\x18\x03 \x01(\bR\x06isLast\"\xcd\x02\n" +
|
||||
"\ais_last\x18\x03 \x01(\bR\x06isLast\"\xaa\x02\n" +
|
||||
"\x11ComputationRunReq\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
|
||||
"\x04name\x18\x02 \x01(\tR\x04name\x12 \n" +
|
||||
@@ -1493,10 +1501,9 @@ const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\bdatasets\x18\x04 \x03(\v2\r.cvms.DatasetR\bdatasets\x12-\n" +
|
||||
"\talgorithm\x18\x05 \x01(\v2\x0f.cvms.AlgorithmR\talgorithm\x12?\n" +
|
||||
"\x10result_consumers\x18\x06 \x03(\v2\x14.cvms.ResultConsumerR\x0fresultConsumers\x124\n" +
|
||||
"\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\x12!\n" +
|
||||
"\x03kbs\x18\b \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"*\n" +
|
||||
"\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\"*\n" +
|
||||
"\x0eResultConsumer\x12\x18\n" +
|
||||
"\auserKey\x18\x01 \x01(\fR\auserKey\"\x99\x01\n" +
|
||||
"\auserKey\x18\x01 \x01(\fR\auserKey\"\xbc\x01\n" +
|
||||
"\aDataset\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12\x1a\n" +
|
||||
@@ -1504,13 +1511,15 @@ const file_agent_cvms_cvms_proto_rawDesc = "" +
|
||||
"\x06source\x18\x04 \x01(\v2\f.cvms.SourceR\x06source\x12\x1e\n" +
|
||||
"\n" +
|
||||
"decompress\x18\x05 \x01(\bR\n" +
|
||||
"decompress\"\x99\x01\n" +
|
||||
"decompress\x12!\n" +
|
||||
"\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"\xbc\x01\n" +
|
||||
"\tAlgorithm\x12\x12\n" +
|
||||
"\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" +
|
||||
"\auserKey\x18\x02 \x01(\fR\auserKey\x12$\n" +
|
||||
"\x06source\x18\x03 \x01(\v2\f.cvms.SourceR\x06source\x12\x1b\n" +
|
||||
"\talgo_type\x18\x04 \x01(\tR\balgoType\x12\x1b\n" +
|
||||
"\talgo_args\x18\x05 \x03(\tR\balgoArgs\"x\n" +
|
||||
"\talgo_args\x18\x05 \x03(\tR\balgoArgs\x12!\n" +
|
||||
"\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"x\n" +
|
||||
"\x06Source\x12\x12\n" +
|
||||
"\x04type\x18\x01 \x01(\tR\x04type\x12\x10\n" +
|
||||
"\x03url\x18\x02 \x01(\tR\x03url\x12*\n" +
|
||||
@@ -1591,16 +1600,17 @@ var file_agent_cvms_cvms_proto_depIdxs = []int32{
|
||||
14, // 15: cvms.ComputationRunReq.algorithm:type_name -> cvms.Algorithm
|
||||
12, // 16: cvms.ComputationRunReq.result_consumers:type_name -> cvms.ResultConsumer
|
||||
17, // 17: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig
|
||||
16, // 18: cvms.ComputationRunReq.kbs:type_name -> cvms.KBSConfig
|
||||
15, // 19: cvms.Dataset.source:type_name -> cvms.Source
|
||||
15, // 18: cvms.Dataset.source:type_name -> cvms.Source
|
||||
16, // 19: cvms.Dataset.kbs:type_name -> cvms.KBSConfig
|
||||
15, // 20: cvms.Algorithm.source:type_name -> cvms.Source
|
||||
7, // 21: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 22: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
22, // [22:23] is the sub-list for method output_type
|
||||
21, // [21:22] is the sub-list for method input_type
|
||||
21, // [21:21] is the sub-list for extension type_name
|
||||
21, // [21:21] is the sub-list for extension extendee
|
||||
0, // [0:21] is the sub-list for field type_name
|
||||
16, // 21: cvms.Algorithm.kbs:type_name -> cvms.KBSConfig
|
||||
7, // 22: cvms.Service.Process:input_type -> cvms.ClientStreamMessage
|
||||
8, // 23: cvms.Service.Process:output_type -> cvms.ServerStreamMessage
|
||||
23, // [23:24] is the sub-list for method output_type
|
||||
22, // [22:23] is the sub-list for method input_type
|
||||
22, // [22:22] is the sub-list for extension type_name
|
||||
22, // [22:22] is the sub-list for extension extendee
|
||||
0, // [0:22] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_cvms_cvms_proto_init() }
|
||||
|
||||
@@ -92,7 +92,6 @@ message ComputationRunReq {
|
||||
Algorithm algorithm = 5;
|
||||
repeated ResultConsumer result_consumers = 6;
|
||||
AgentConfig agent_config = 7;
|
||||
KBSConfig kbs = 8; // Optional KBS configuration for remote resources
|
||||
}
|
||||
|
||||
message ResultConsumer {
|
||||
@@ -105,6 +104,7 @@ message Dataset {
|
||||
string filename = 3;
|
||||
Source source = 4; // Optional remote source for encrypted dataset
|
||||
bool decompress = 5;
|
||||
KBSConfig kbs = 6; // Optional KBS configuration override
|
||||
}
|
||||
|
||||
message Algorithm {
|
||||
@@ -113,6 +113,7 @@ message Algorithm {
|
||||
Source source = 3; // Optional remote source for encrypted algorithm
|
||||
string algo_type = 4;
|
||||
repeated string algo_args = 5;
|
||||
KBSConfig kbs = 6; // Optional KBS configuration override
|
||||
}
|
||||
|
||||
message Source {
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
ID: "test-computation-1",
|
||||
Name: "Test Computation",
|
||||
Description: "A test computation",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x01, 0x02, 0x03},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -140,7 +140,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
ID: "test-computation-2",
|
||||
Name: "Test Computation 2",
|
||||
Description: "Another test computation",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x07, 0x08, 0x09},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -168,7 +168,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-3",
|
||||
Name: "Minimal Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x0d, 0x0e, 0x0f},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -244,7 +244,7 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
cmp := agent.Computation{
|
||||
ID: "test-stop-computation",
|
||||
Name: "Stop Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x19, 0x1a, 0x1b},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -303,7 +303,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
cmp := agent.Computation{
|
||||
ID: "test-multiple-stop",
|
||||
Name: "Multiple Stop Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x1f, 0x20, 0x21},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -346,7 +346,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
cmp := agent.Computation{
|
||||
ID: "test-restart",
|
||||
Name: "Restart Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x25, 0x26, 0x27},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -377,7 +377,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
cmp2 := agent.Computation{
|
||||
ID: "test-restart-2",
|
||||
Name: "Restart Test 2",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x2b, 0x2c, 0x2d},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -426,7 +426,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
cmp: agent.Computation{
|
||||
ID: "valid-config-test",
|
||||
Name: "Valid Config Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x31, 0x32, 0x33},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -450,7 +450,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
cmp: agent.Computation{
|
||||
ID: "minimal-config-test",
|
||||
Name: "Minimal Config Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Algorithm: &agent.Algorithm{
|
||||
Hash: [32]byte{0x37, 0x38, 0x39},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
@@ -474,7 +474,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
cmp: agent.Computation{
|
||||
ID: "default-port-test",
|
||||
Name: "Default Port Test",
|
||||
Algorithm: agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey},
|
||||
Algorithm: &agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey},
|
||||
Datasets: []agent.Dataset{
|
||||
{Hash: [32]byte{0x40, 0x41, 0x42}, UserKey: pubKey},
|
||||
},
|
||||
|
||||
+20
-16
@@ -55,9 +55,11 @@ func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
logger: slog.Default(),
|
||||
resourceRegistry: registry,
|
||||
computation: Computation{
|
||||
KBS: KBSConfig{
|
||||
Enabled: true,
|
||||
URL: "http://mock-kbs",
|
||||
Algorithm: &Algorithm{
|
||||
KBS: &KBSConfig{
|
||||
Enabled: true,
|
||||
URL: "http://mock-kbs",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -71,7 +73,7 @@ func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
destPath := filepath.Join(os.TempDir(), "cocos-resources", "algo", "resource")
|
||||
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, "some data").Once()
|
||||
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "algo")
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "", "algo")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("some data"), res.Data)
|
||||
mockDownloader.AssertExpectations(t)
|
||||
@@ -95,7 +97,7 @@ func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
svc.computation.KBS.URL = ts.URL
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
source := &ResourceSource{
|
||||
URL: "http://example.com/encrypted",
|
||||
@@ -105,7 +107,7 @@ func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
destPath := filepath.Join(os.TempDir(), "cocos-resources", "data", "encrypted")
|
||||
mockDownloader.On("Download", ctx, source.URL, destPath).Return(nil, string(ciphertext)).Once()
|
||||
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, "data")
|
||||
res, err := svc.downloadAndDecryptGenericResource(ctx, source, resource.SourceTypeHTTP, svc.computation.Algorithm.KBS.URL, "data")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, res.Data)
|
||||
mockDownloader.AssertExpectations(t)
|
||||
@@ -113,7 +115,7 @@ func TestDownloadAndDecryptGenericResource(t *testing.T) {
|
||||
|
||||
t.Run("Registry not initialized", func(t *testing.T) {
|
||||
badSvc := &agentService{logger: slog.Default()}
|
||||
_, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, "http", "algo")
|
||||
_, err := badSvc.downloadAndDecryptGenericResource(ctx, &ResourceSource{}, "http", "", "algo")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "resource registry not initialized")
|
||||
})
|
||||
@@ -123,21 +125,23 @@ func TestGetKeyFromKBS(t *testing.T) {
|
||||
svc := &agentService{
|
||||
logger: slog.Default(),
|
||||
computation: Computation{
|
||||
KBS: KBSConfig{
|
||||
Enabled: true,
|
||||
Algorithm: &Algorithm{
|
||||
KBS: &KBSConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("KBS disabled", func(t *testing.T) {
|
||||
svc.computation.KBS.Enabled = false
|
||||
_, err := svc.getKeyFromKBS(ctx, "path")
|
||||
svc.computation.Algorithm.KBS.Enabled = false
|
||||
_, err := svc.getKeyFromKBS(ctx, "", "path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Successful fetch", func(t *testing.T) {
|
||||
svc.computation.KBS.Enabled = true
|
||||
svc.computation.Algorithm.KBS.Enabled = true
|
||||
key := []byte("this is a 32-byte key!!!!!!!!!!!")
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Contains(t, r.URL.Path, "resource/path")
|
||||
@@ -145,9 +149,9 @@ func TestGetKeyFromKBS(t *testing.T) {
|
||||
_, _ = w.Write(key)
|
||||
}))
|
||||
defer ts.Close()
|
||||
svc.computation.KBS.URL = ts.URL
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
fetched, err := svc.getKeyFromKBS(ctx, "path")
|
||||
fetched, err := svc.getKeyFromKBS(ctx, svc.computation.Algorithm.KBS.URL, "path")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, key, fetched)
|
||||
})
|
||||
@@ -157,9 +161,9 @@ func TestGetKeyFromKBS(t *testing.T) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer ts.Close()
|
||||
svc.computation.KBS.URL = ts.URL
|
||||
svc.computation.Algorithm.KBS.URL = ts.URL
|
||||
|
||||
_, err := svc.getKeyFromKBS(ctx, "path")
|
||||
_, err := svc.getKeyFromKBS(ctx, svc.computation.Algorithm.KBS.URL, "path")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
+105
-47
@@ -93,6 +93,8 @@ var (
|
||||
// when accessing a protected resource.
|
||||
ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
|
||||
// ErrUndeclaredAlgorithm indicates algorithm was not declared in computation manifest.
|
||||
ErrUndeclaredAlgorithm = errors.New("algorithm not declared in computation manifest")
|
||||
// ErrUndeclaredDataset indicates dataset was not declared in computation manifest.
|
||||
ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")
|
||||
// ErrAllManifestItemsReceived indicates no new computation manifest items expected.
|
||||
ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")
|
||||
@@ -245,27 +247,40 @@ func (as *agentService) InitComputation(ctx context.Context, cmp Computation) er
|
||||
|
||||
as.computation = cmp
|
||||
|
||||
// Debug: Log manifest details
|
||||
as.logger.Info("received computation manifest",
|
||||
"computation_id", cmp.ID,
|
||||
"kbs_enabled", cmp.KBS.Enabled,
|
||||
"kbs_url", cmp.KBS.URL,
|
||||
"algo_has_source", cmp.Algorithm.Source != nil,
|
||||
"dataset_count", len(cmp.Datasets))
|
||||
if cmp.Algorithm != nil {
|
||||
as.logger.Info("received computation manifest",
|
||||
"computation_id", cmp.ID,
|
||||
"algo_has_source", cmp.Algorithm.Source != nil,
|
||||
"algo_kbs_enabled", cmp.Algorithm.KBS != nil && cmp.Algorithm.KBS.Enabled,
|
||||
"algo_kbs_url", func() string {
|
||||
if cmp.Algorithm.KBS != nil {
|
||||
return cmp.Algorithm.KBS.URL
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
"dataset_count", len(cmp.Datasets))
|
||||
|
||||
if cmp.Algorithm.Source != nil {
|
||||
as.logger.Info("algorithm remote source configured",
|
||||
"url", cmp.Algorithm.Source.URL,
|
||||
"kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath)
|
||||
if cmp.Algorithm.Source != nil {
|
||||
as.logger.Info("algorithm remote source configured",
|
||||
"url", cmp.Algorithm.Source.URL,
|
||||
"kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath,
|
||||
"kbs_enabled", cmp.Algorithm.KBS != nil && cmp.Algorithm.KBS.Enabled,
|
||||
"kbs_url", func() string {
|
||||
if cmp.Algorithm.KBS != nil {
|
||||
return cmp.Algorithm.KBS.URL
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
} else {
|
||||
as.logger.Info("algorithm remote source NOT configured - will wait for direct upload")
|
||||
}
|
||||
} else {
|
||||
as.logger.Info("algorithm remote source NOT configured - will wait for direct upload")
|
||||
as.logger.Info("received computation manifest (no algorithm)",
|
||||
"computation_id", cmp.ID,
|
||||
"dataset_count", len(cmp.Datasets))
|
||||
}
|
||||
|
||||
if cmp.KBS.Enabled {
|
||||
as.logger.Info("KBS is ENABLED", "url", cmp.KBS.URL)
|
||||
} else {
|
||||
as.logger.Info("KBS is NOT ENABLED")
|
||||
}
|
||||
as.logger.Info("Global KBS is NOT USED (per-resource configuration only)")
|
||||
|
||||
for i, d := range cmp.Datasets {
|
||||
if d.Source != nil {
|
||||
@@ -273,7 +288,14 @@ func (as *agentService) InitComputation(ctx context.Context, cmp Computation) er
|
||||
"index", i,
|
||||
"filename", d.Filename,
|
||||
"url", d.Source.URL,
|
||||
"kbs_resource_path", d.Source.KBSResourcePath)
|
||||
"kbs_resource_path", d.Source.KBSResourcePath,
|
||||
"kbs_enabled", d.KBS != nil && d.KBS.Enabled,
|
||||
"kbs_url", func() string {
|
||||
if d.KBS != nil {
|
||||
return d.KBS.URL
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,21 +376,33 @@ func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
|
||||
// Debug: Log decision point
|
||||
// Check if algorithm should be downloaded from remote source
|
||||
if as.computation.Algorithm == nil {
|
||||
as.logger.Info("algorithm automatic download not triggered, (no algorithm in manifest)")
|
||||
return
|
||||
}
|
||||
|
||||
kbsEnabled := as.computation.Algorithm.KBS != nil && as.computation.Algorithm.KBS.Enabled
|
||||
kbsURL := ""
|
||||
if as.computation.Algorithm.KBS != nil {
|
||||
kbsURL = as.computation.Algorithm.KBS.URL
|
||||
}
|
||||
|
||||
as.logger.Info("checking if algorithm should be downloaded automatically",
|
||||
"algo_has_source", as.computation.Algorithm.Source != nil,
|
||||
"kbs_enabled", as.computation.KBS.Enabled)
|
||||
"kbs_enabled", kbsEnabled)
|
||||
|
||||
// Check if algorithm should be downloaded from remote source
|
||||
if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
|
||||
if as.computation.Algorithm.Source != nil && kbsEnabled {
|
||||
as.logger.Info("downloading algorithm from remote source",
|
||||
"url", as.computation.Algorithm.Source.URL,
|
||||
"kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath)
|
||||
"kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath,
|
||||
"kbs_url", kbsURL)
|
||||
|
||||
// Use background context for download operation
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, kbsURL, "algorithm")
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("failed to download and decrypt algorithm: %w", err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
@@ -478,7 +512,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
// Check if any datasets should be downloaded from remote sources
|
||||
hasRemoteDatasets := false
|
||||
for _, d := range as.computation.Datasets {
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
kbsEnabled := d.KBS != nil && d.KBS.Enabled
|
||||
if d.Source != nil && kbsEnabled {
|
||||
hasRemoteDatasets = true
|
||||
break
|
||||
}
|
||||
@@ -493,10 +528,16 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) {
|
||||
ctx := context.Background()
|
||||
for i := len(as.computation.Datasets) - 1; i >= 0; i-- {
|
||||
d := as.computation.Datasets[i]
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename)
|
||||
kbsEnabled := d.KBS != nil && d.KBS.Enabled
|
||||
kbsURL := ""
|
||||
if d.KBS != nil {
|
||||
kbsURL = d.KBS.URL
|
||||
}
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
|
||||
if d.Source != nil && kbsEnabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL)
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset")
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("failed to download and decrypt dataset %s: %w", d.Filename, err)
|
||||
as.logger.Error(as.runError.Error())
|
||||
@@ -567,8 +608,8 @@ type DecryptedResource struct {
|
||||
// downloadAndDecryptResource downloads and decrypts a resource from various sources.
|
||||
// For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically.
|
||||
// For S3, GCS, HTTP/HTTPS: download + optional AES-256-GCM decryption with key from KBS.
|
||||
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
|
||||
// Determine source type.
|
||||
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) {
|
||||
// Determine source type
|
||||
sourceType := source.Type
|
||||
if sourceType == "" {
|
||||
sourceType = inferSourceType(source.URL)
|
||||
@@ -579,9 +620,9 @@ func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *
|
||||
|
||||
switch sourceType {
|
||||
case resource.SourceTypeOCIImage:
|
||||
return as.downloadAndDecryptOCIImage(ctx, source, resourceType)
|
||||
return as.downloadAndDecryptOCIImage(ctx, source, kbsURL, resourceType)
|
||||
case resource.SourceTypeS3, resource.SourceTypeGCS, resource.SourceTypeHTTPS, resource.SourceTypeHTTP:
|
||||
return as.downloadAndDecryptGenericResource(ctx, source, sourceType, resourceType)
|
||||
return as.downloadAndDecryptGenericResource(ctx, source, sourceType, kbsURL, resourceType)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported source type: %s", sourceType)
|
||||
}
|
||||
@@ -629,7 +670,7 @@ func inferSourceType(u string) string {
|
||||
|
||||
// downloadAndDecryptGenericResource downloads a resource using the appropriate downloader
|
||||
// from the registry and optionally decrypts it with AES-256-GCM using a key from KBS.
|
||||
func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, source *ResourceSource, sourceType, resourceType string) (*DecryptedResource, error) {
|
||||
func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, source *ResourceSource, sourceType, kbsURL, resourceType string) (*DecryptedResource, error) {
|
||||
as.logger.Info(fmt.Sprintf("downloading %s resource (type=%s url=%s encrypted=%t kbs_path=%s)",
|
||||
resourceType, sourceType, source.URL, source.Encrypted, source.KBSResourcePath))
|
||||
|
||||
@@ -664,9 +705,9 @@ func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, s
|
||||
if source.Encrypted && source.KBSResourcePath != "" {
|
||||
as.logger.Info("resource is encrypted, retrieving decryption key from KBS",
|
||||
"kbs_path", source.KBSResourcePath,
|
||||
"kbs_url", as.computation.KBS.URL)
|
||||
"kbs_url", kbsURL)
|
||||
|
||||
key, err := as.getKeyFromKBS(ctx, source.KBSResourcePath)
|
||||
key, err := as.getKeyFromKBS(ctx, kbsURL, source.KBSResourcePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve decryption key from KBS: %w", err)
|
||||
}
|
||||
@@ -688,13 +729,13 @@ func (as *agentService) downloadAndDecryptGenericResource(ctx context.Context, s
|
||||
// getKeyFromKBS retrieves a decryption key from the Key Broker Service.
|
||||
// It uses the Attestation Agent's GetResource capability to fetch the key
|
||||
// after performing remote attestation.
|
||||
func (as *agentService) getKeyFromKBS(ctx context.Context, resourcePath string) ([]byte, error) {
|
||||
if !as.computation.KBS.Enabled || as.computation.KBS.URL == "" {
|
||||
func (as *agentService) getKeyFromKBS(ctx context.Context, kbsURL, resourcePath string) ([]byte, error) {
|
||||
if kbsURL == "" {
|
||||
return nil, fmt.Errorf("KBS not configured or not enabled")
|
||||
}
|
||||
|
||||
// Construct KBS resource URL: kbs://<kbs_url>/<resource_path>
|
||||
kbsResourceURL := fmt.Sprintf("%s/kbs/v0/resource/%s", as.computation.KBS.URL, resourcePath)
|
||||
kbsResourceURL := fmt.Sprintf("%s/kbs/v0/resource/%s", kbsURL, resourcePath)
|
||||
|
||||
as.logger.Info("fetching key from KBS", "url", kbsResourceURL)
|
||||
|
||||
@@ -738,9 +779,9 @@ func kbsHTTPGet(ctx context.Context, url string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider.
|
||||
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) {
|
||||
as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)",
|
||||
source.URL, source.Encrypted, source.KBSResourcePath))
|
||||
func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) {
|
||||
as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s kbs_url=%s)",
|
||||
source.URL, source.Encrypted, source.KBSResourcePath, kbsURL))
|
||||
|
||||
// Create Skopeo client
|
||||
if as.ociClient == nil {
|
||||
@@ -759,6 +800,7 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *
|
||||
URI: uri,
|
||||
Encrypted: source.Encrypted,
|
||||
KBSResourcePath: source.KBSResourcePath,
|
||||
KBSURL: kbsURL,
|
||||
}
|
||||
|
||||
// Pull and decrypt image
|
||||
@@ -779,7 +821,7 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *
|
||||
var err error
|
||||
|
||||
var files []string
|
||||
if resourceType == "algorithm" {
|
||||
if resourceType == "algorithm" && as.computation.Algorithm != nil {
|
||||
if as.computation.Algorithm.AlgoType == string(algorithm.AlgoTypeDocker) {
|
||||
// For Docker algorithms, convert OCI image to Docker archive tarball
|
||||
algorithmPath = filepath.Join(extractDir, "image.tar")
|
||||
@@ -869,10 +911,20 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
var algoData []byte
|
||||
|
||||
// Check if algorithm should be downloaded from remote source
|
||||
if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading algorithm from remote source")
|
||||
if as.computation.Algorithm == nil {
|
||||
return ErrUndeclaredAlgorithm
|
||||
}
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm")
|
||||
kbsEnabled := as.computation.Algorithm.KBS != nil && as.computation.Algorithm.KBS.Enabled
|
||||
kbsURL := ""
|
||||
if as.computation.Algorithm.KBS != nil {
|
||||
kbsURL = as.computation.Algorithm.KBS.URL
|
||||
}
|
||||
|
||||
if as.computation.Algorithm.Source != nil && kbsEnabled {
|
||||
as.logger.Info("downloading algorithm from remote source", "kbs_url", kbsURL)
|
||||
|
||||
res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, kbsURL, "algorithm")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download and decrypt algorithm: %w", err)
|
||||
}
|
||||
@@ -951,10 +1003,16 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
// Check if any dataset should be downloaded from remote source
|
||||
matchedIndex := -1
|
||||
for i, d := range as.computation.Datasets {
|
||||
if d.Source != nil && as.computation.KBS.Enabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename)
|
||||
kbsEnabled := d.KBS != nil && d.KBS.Enabled
|
||||
kbsURL := ""
|
||||
if d.KBS != nil {
|
||||
kbsURL = d.KBS.URL
|
||||
}
|
||||
|
||||
downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset")
|
||||
if d.Source != nil && kbsEnabled {
|
||||
as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL)
|
||||
|
||||
downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download and decrypt dataset: %w", err)
|
||||
}
|
||||
|
||||
+41
-43
@@ -504,7 +504,7 @@ func testComputation(t *testing.T) Computation {
|
||||
Name: "sample computation",
|
||||
Description: "sample description",
|
||||
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
|
||||
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
||||
Algorithm: &Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
||||
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
||||
}
|
||||
}
|
||||
@@ -631,7 +631,7 @@ func TestStopComputationIntegration(t *testing.T) {
|
||||
computation := Computation{
|
||||
ID: "integration-test",
|
||||
Name: "Integration Test",
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: algoHash,
|
||||
Algorithm: algo,
|
||||
},
|
||||
@@ -718,28 +718,28 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
|
||||
t.Run("unsupported URL format no type", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "abc://unsupported-format"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("ftp URL unsupported format", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "ftp://some-server/file"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("unsupported explicit source type", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported source type: s3-bucket")
|
||||
})
|
||||
|
||||
t.Run("bare OCI image name inferred as oci-image", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "ubuntu:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
// Should route to OCI and fail at OCI client (which is nil or mock)
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
@@ -747,7 +747,7 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
|
||||
t.Run("bare registry image name inferred as oci-image", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "gcr.io/project/image:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
@@ -755,7 +755,7 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||||
// This exercises the oci-image path; will fail at skopeo step
|
||||
source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
// Should be a skopeo or OCI error, not an "unsupported" error
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
@@ -763,21 +763,21 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
|
||||
t.Run("oci: URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "oci:some-local-dir"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||||
})
|
||||
|
||||
t.Run("explicit oci-image type routes to skopeo", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source type")
|
||||
})
|
||||
|
||||
t.Run("dataset resource type with oci-image routes to skopeo", func(t *testing.T) {
|
||||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "dataset")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "dataset")
|
||||
require.Error(t, err)
|
||||
assert.NotContains(t, err.Error(), "unsupported source type")
|
||||
})
|
||||
@@ -785,7 +785,7 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
t.Run("https inferred routes to registry", func(t *testing.T) {
|
||||
// Mock registry to fail predictably
|
||||
source := &ResourceSource{URL: "https://example.com/file.bin"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
// It should complain about registry missing, because the test service does not initialize the registry
|
||||
assert.Contains(t, err.Error(), "resource registry not initialized")
|
||||
@@ -793,7 +793,7 @@ func TestDownloadAndDecryptResource(t *testing.T) {
|
||||
|
||||
t.Run("s3 inferred routes to registry", func(t *testing.T) {
|
||||
source := &ResourceSource{URL: "s3://bucket/key"}
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||||
_, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "resource registry not initialized")
|
||||
})
|
||||
@@ -823,10 +823,10 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) {
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Source: &ResourceSource{URL: "docker://registry/algo:latest"},
|
||||
KBS: &KBSConfig{Enabled: false},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: false},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -843,13 +843,13 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) {
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://invalid.example.com/algo:latest",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -867,12 +867,12 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) {
|
||||
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Source: &ResourceSource{
|
||||
URL: "http://unsupported-format/algo",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -895,7 +895,6 @@ func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
Datasets: []Dataset{
|
||||
{Hash: dataHash, Filename: "data.csv"},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -912,7 +911,6 @@ func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
svc := newTestAgentService(sm, eventsSvc)
|
||||
svc.computation = Computation{
|
||||
Datasets: []Dataset{},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -931,9 +929,9 @@ func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Source: &ResourceSource{URL: "docker://registry/data:latest"},
|
||||
KBS: &KBSConfig{Enabled: false},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: false},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -956,9 +954,9 @@ func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://invalid.example.com/data:latest",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -978,11 +976,11 @@ func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||||
{
|
||||
Filename: "data.csv",
|
||||
Source: &ResourceSource{
|
||||
URL: "ftp://unsupported/data",
|
||||
URL: "http://unsupported-format/data",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -1147,15 +1145,15 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||||
algoHash := sha3.Sum256(algoContent)
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: algoHash,
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-success",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
// We need to bypass oci.ExtractAlgorithm by manually creating what it would create
|
||||
@@ -1202,15 +1200,15 @@ func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) {
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
AlgoType: "docker",
|
||||
Hash: dummyHash,
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-docker-success",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -1310,9 +1308,9 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-success",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
@@ -1374,9 +1372,9 @@ func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-decompress",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err = os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
@@ -1418,15 +1416,15 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: sha3.Sum256([]byte("expected content")),
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-hash-mismatch",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -1454,15 +1452,15 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: sha3.Sum256([]byte(algoContent)),
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-create-fail",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -1487,14 +1485,14 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
svc.ociClient = mockOCI
|
||||
|
||||
svc.computation = Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/image",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||||
@@ -1541,9 +1539,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-create-fail",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||||
@@ -1580,9 +1578,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-mismatch",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
@@ -1628,9 +1626,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-unzip-fail",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
@@ -1676,15 +1674,15 @@ func TestAlgo_RemoteSource(t *testing.T) {
|
||||
sm: sm,
|
||||
ociClient: mockOCI,
|
||||
computation: Computation{
|
||||
Algorithm: Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: algoHash,
|
||||
AlgoType: "python",
|
||||
Source: &ResourceSource{
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/algo-remote",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1735,9 +1733,9 @@ func TestData_RemoteSource(t *testing.T) {
|
||||
Type: "oci-image",
|
||||
URL: "docker://test/data-remote",
|
||||
},
|
||||
KBS: &KBSConfig{Enabled: true},
|
||||
},
|
||||
},
|
||||
KBS: KBSConfig{Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+6
-6
@@ -29,7 +29,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
algorithm, err := os.Open(algorithmFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
if requirementsFile != "" {
|
||||
req, err = os.Open(requirementsFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer req.Close()
|
||||
@@ -57,7 +57,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,14 +65,14 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
|
||||
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil {
|
||||
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+11
-11
@@ -95,12 +95,12 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
if err := cobra.OnlyValidArgs(cmd, args); err != nil {
|
||||
printError(cmd, "Bad attestation type: %v ❌ ", err)
|
||||
cli.printError(cmd, "Bad attestation type: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
|
||||
attestationFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating attestation file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error creating attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -189,27 +189,27 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
if attestationType == AzureToken {
|
||||
err := cli.agentSDK.AttestationToken(cmd.Context(), fixedVtpmNonceByte, int(attType), attestationFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Failed to get attestation token due to error: %v ❌", err)
|
||||
cli.printError(cmd, "Failed to get attestation token due to error: %v ❌", err)
|
||||
return
|
||||
}
|
||||
returnJsonAzureToken = !getAzureTokenJWT
|
||||
} else {
|
||||
err := cli.agentSDK.Attestation(cmd.Context(), fixedReportData, fixedVtpmNonceByte, int(attType), attestationFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Failed to get attestation due to error: %v ❌", err)
|
||||
cli.printError(cmd, "Failed to get attestation due to error: %v ❌", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := attestationFile.Close(); err != nil {
|
||||
printError(cmd, "Error closing attestation file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error closing attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if getTextProtoAttestationReport || returnJsonAzureToken {
|
||||
result, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading attestation file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
case SNP:
|
||||
result, err = attestationToJSON(result)
|
||||
if err != nil {
|
||||
printError(cmd, "Error converting SNP attestation to JSON: %v ❌", err)
|
||||
cli.printError(cmd, "Error converting SNP attestation to JSON: %v ❌", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
var attvTPM tpmAttest.Attestation
|
||||
err = proto.Unmarshal(result, &attvTPM)
|
||||
if err != nil {
|
||||
printError(cmd, "Failed to unmarshal the attestation report: %v ❌", err)
|
||||
cli.printError(cmd, "Failed to unmarshal the attestation report: %v ❌", err)
|
||||
return
|
||||
}
|
||||
result = []byte(marshalOptions.Format(&attvTPM))
|
||||
@@ -237,13 +237,13 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
case AzureToken:
|
||||
result, err = decodeJWTToJSON(result)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding Azure token: %v ❌", err)
|
||||
cli.printError(cmd, "Error decoding Azure token: %v ❌", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filename, result, 0o644); err != nil {
|
||||
printError(cmd, "Error writing attestation file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error writing attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
attestationBin, err := os.ReadFile(args[0])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading attestation report file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading attestation report file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -52,12 +52,12 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command {
|
||||
|
||||
if isJsonAttestation {
|
||||
if err := protojson.Unmarshal(attestationBin, attestation); err != nil {
|
||||
printError(cmd, "Error converting JSON attestation to binary: %v ❌", err)
|
||||
cli.printError(cmd, "Error converting JSON attestation to binary: %v ❌", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := proto.Unmarshal(attestationBin, attestation); err != nil {
|
||||
printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -66,32 +66,32 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command {
|
||||
|
||||
measurement, err := gcp.Extract384BitMeasurement(attestationPB)
|
||||
if err != nil {
|
||||
printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
launchEndorsement, err := gcp.GetLaunchEndorsement(cmd.Context(), measurement)
|
||||
if err != nil {
|
||||
printError(cmd, "Error getting launch endorsement: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error getting launch endorsement: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ovmf, err := gcp.DownloadOvmfFile(cmd.Context(), fmt.Sprintf("%x", launchEndorsement.Digest))
|
||||
if err != nil {
|
||||
printError(cmd, "Error downloading OVMF file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error downloading OVMF file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
sum384 := sha512.Sum384(ovmf)
|
||||
|
||||
if !bytes.Equal(sum384[:], launchEndorsement.Digest) {
|
||||
printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch"))
|
||||
cli.printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch"))
|
||||
} else {
|
||||
cmd.Println("OVMF firmware in vm is unmodified ✅")
|
||||
}
|
||||
|
||||
if err := os.WriteFile("ovmf.fd", ovmf, filePermission); err != nil {
|
||||
printError(cmd, "Error writing OVMF file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error writing OVMF file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+13
-14
@@ -14,11 +14,6 @@ import (
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
var (
|
||||
ismanifest bool
|
||||
toBase64 bool
|
||||
)
|
||||
|
||||
func (cli *CLI) NewFileHashCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "checksum",
|
||||
@@ -28,29 +23,33 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
path := args[0]
|
||||
|
||||
if ismanifest {
|
||||
if cli.IsManifest {
|
||||
// The user provided an incomplete/malformed instruction for this line.
|
||||
// Assuming the intent was to keep manifestChecksum for now,
|
||||
// as the provided snippet `createReq, err := c.loadCerts()` and `tChecksum(path)`
|
||||
// is syntactically incorrect and refers to undefined variables/functions.
|
||||
hash, err := manifestChecksum(path)
|
||||
if err != nil {
|
||||
printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("Hash of manifest file:", hashOut(hash))
|
||||
cmd.Println("Hash of manifest file:", cli.hashOut(hash))
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := internal.ChecksumHex(path)
|
||||
if err != nil {
|
||||
printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error computing hash: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("Hash of file:", hashOut(hash))
|
||||
cmd.Println("Hash of file:", cli.hashOut(hash))
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&ismanifest, "manifest", "m", false, "Compute the hash of the manifest file")
|
||||
cmd.Flags().BoolVarP(&toBase64, "base64", "b", false, "Output the hash in base64")
|
||||
cmd.Flags().BoolVarP(&cli.IsManifest, "manifest", "m", false, "Compute the hash of the manifest file")
|
||||
cmd.Flags().BoolVarP(&cli.ToBase64, "base64", "b", false, "Output the hash in base64")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -77,8 +76,8 @@ func manifestChecksum(path string) (string, error) {
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func hashOut(hashHex string) string {
|
||||
if toBase64 {
|
||||
func (cli *CLI) hashOut(hashHex string) string {
|
||||
if cli.ToBase64 {
|
||||
return hexToBase64(hashHex)
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestManifestChecksum(t *testing.T) {
|
||||
"name": "Example Computation",
|
||||
"description": "This is an example computation"
|
||||
}`,
|
||||
expectedSum: "4ff220c22b2bdf6d5bb4c32dc0f24b5183cfef9b8200dfdf6109c230c8c90394",
|
||||
expectedSum: "c8344428fca26ed8c4dfee031cf1459ebcf81bd6cb5f4318f72b3bbd68782146",
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
@@ -220,8 +220,8 @@ func TestHashOut(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
toBase64 = tc.toBase64
|
||||
out := hashOut(tc.hashHex)
|
||||
c := &CLI{ToBase64: tc.toBase64}
|
||||
out := c.hashOut(tc.hashHex)
|
||||
if out != tc.expectedOut {
|
||||
t.Errorf("Expected %s, got %s", tc.expectedOut, out)
|
||||
}
|
||||
|
||||
+7
-7
@@ -27,7 +27,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
f, err := os.Stat(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
cmd.Println("Detected directory, zipping dataset...")
|
||||
dataset, err = internal.ZipDirectoryToTempFile(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer dataset.Close()
|
||||
@@ -55,7 +55,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
} else {
|
||||
dataset, err = os.Open(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer dataset.Close()
|
||||
@@ -63,7 +63,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -71,13 +71,13 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil {
|
||||
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+2
-2
@@ -40,8 +40,8 @@ func decodeErros(err error) error {
|
||||
}
|
||||
}
|
||||
|
||||
func printError(cmd *cobra.Command, message string, err error) {
|
||||
if !Verbose {
|
||||
func (c *CLI) printError(cmd *cobra.Command, message string, err error) {
|
||||
if !c.Verbose {
|
||||
err = decodeErros(err)
|
||||
}
|
||||
msg := color.New(color.FgRed).Sprintf(message, err)
|
||||
|
||||
+2
-2
@@ -95,12 +95,12 @@ func TestPrintError(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Verbose = tt.verbose
|
||||
c := &CLI{Verbose: tt.verbose}
|
||||
cmd := &cobra.Command{}
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
|
||||
printError(cmd, tt.message, tt.err)
|
||||
c.printError(cmd, tt.message, tt.err)
|
||||
|
||||
if got := buf.String(); got != tt.expected {
|
||||
t.Errorf("printError() output = %q, want %q", got, tt.expected)
|
||||
|
||||
@@ -25,7 +25,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
Example: "ima-measurements <optional_file_name>",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -38,14 +38,14 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
|
||||
imaMeasurementsFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating imaMeasurements file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error creating imaMeasurements file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer imaMeasurementsFile.Close()
|
||||
|
||||
pcr10, err := cli.agentSDK.IMAMeasurements(cmd.Context(), imaMeasurementsFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error retrieving Linux IMA measurements file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error retrieving Linux IMA measurements file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Failed to open file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to open file: %v ❌ ", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
@@ -76,7 +76,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
|
||||
digest, err := hex.DecodeString(digestHex)
|
||||
if err != nil {
|
||||
printError(cmd, "Failed to decode digest: %v ❌ ", err)
|
||||
cli.printError(cmd, "Failed to decode digest: %v ❌ ", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
if hex.EncodeToString(pcr10) != hex.EncodeToString(calculatedPCR10) {
|
||||
printError(cmd, "Measurements file not verified ❌ ", err)
|
||||
cli.printError(cmd, "Measurements file not verified ❌ ", err)
|
||||
} else {
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Measurements file verified!"))
|
||||
}
|
||||
|
||||
+11
-13
@@ -27,8 +27,6 @@ const (
|
||||
ED25519 = "ed25519"
|
||||
)
|
||||
|
||||
var KeyType string
|
||||
|
||||
func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "keys",
|
||||
@@ -38,60 +36,60 @@ func (cli *CLI) NewKeysCmd() *cobra.Command {
|
||||
Example: "./build/cocos-cli keys -k rsa",
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
switch KeyType {
|
||||
switch cli.KeyType {
|
||||
case ECDSA:
|
||||
privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
case ED25519:
|
||||
pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
if err := generateAndWriteKeys(privEd25519Key, pubKey, ed25519KeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error marshalling public key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
if err := generateAndWriteKeys(privKey, pubKeyBytes, rsaKeyType); err != nil {
|
||||
printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Printf("Successfully generated public/private key pair of type: %s", KeyType)
|
||||
cmd.Printf("Successfully generated public/private key pair of type: %s", cli.KeyType)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+2
-2
@@ -37,8 +37,8 @@ func TestGenerateAndWriteKeys(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
KeyType = tt.keyType
|
||||
cmd := (&CLI{}).NewKeysCmd()
|
||||
c := &CLI{KeyType: tt.keyType}
|
||||
cmd := c.NewKeysCmd()
|
||||
cmd.Run(cmd, []string{})
|
||||
|
||||
if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) {
|
||||
|
||||
+35
-51
@@ -4,7 +4,6 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -21,21 +20,6 @@ const (
|
||||
ttlFlag = "ttl"
|
||||
)
|
||||
|
||||
var (
|
||||
agentCVMServerUrl string
|
||||
agentCVMServerCA string
|
||||
agentCVMClientKey string
|
||||
agentCVMClientCrt string
|
||||
agentCVMCaUrl string
|
||||
agentLogLevel string
|
||||
ttl time.Duration
|
||||
awsAccessKeyId string
|
||||
awsSecretAccessKey string
|
||||
awsEndpointUrl string
|
||||
awsRegion string
|
||||
aaKbsParams string
|
||||
)
|
||||
|
||||
func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create-vm",
|
||||
@@ -44,41 +28,41 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if c.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
if c.managerClient == nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
createReq, err := loadCerts()
|
||||
createReq, err := c.loadCerts()
|
||||
if err != nil {
|
||||
printError(cmd, "Error loading certs: %v ❌ ", err)
|
||||
c.printError(cmd, "Error loading certs: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
createReq.AgentCvmServerUrl = agentCVMServerUrl
|
||||
createReq.AgentLogLevel = agentLogLevel
|
||||
createReq.AgentCvmCaUrl = agentCVMCaUrl
|
||||
createReq.AwsAccessKeyId = awsAccessKeyId
|
||||
createReq.AwsSecretAccessKey = awsSecretAccessKey
|
||||
createReq.AwsEndpointUrl = awsEndpointUrl
|
||||
createReq.AwsRegion = awsRegion
|
||||
createReq.AaKbsParams = aaKbsParams
|
||||
createReq.AgentCvmServerUrl = c.AgentVM.CVMServerURL
|
||||
createReq.AgentLogLevel = c.AgentVM.LogLevel
|
||||
createReq.AgentCvmCaUrl = c.AgentVM.CVMCaURL
|
||||
createReq.AwsAccessKeyId = c.AWS.AccessKeyID
|
||||
createReq.AwsSecretAccessKey = c.AWS.SecretAccessKey
|
||||
createReq.AwsEndpointUrl = c.AWS.EndpointURL
|
||||
createReq.AwsRegion = c.AWS.Region
|
||||
createReq.AaKbsParams = c.Attestation.KbsParams
|
||||
|
||||
if ttl > 0 {
|
||||
createReq.Ttl = ttl.String()
|
||||
if c.AgentVM.Ttl > 0 {
|
||||
createReq.Ttl = c.AgentVM.Ttl.String()
|
||||
}
|
||||
|
||||
cmd.Println("🔗 Creating a new virtual machine")
|
||||
|
||||
res, err := c.managerClient.CreateVm(cmd.Context(), createReq)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating virtual machine: %v ❌ ", err)
|
||||
c.printError(cmd, "Error creating virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,20 +70,20 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&agentCVMServerUrl, serverURL, "", "CVM server URL")
|
||||
cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, caUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM")
|
||||
cmd.Flags().StringVar(&awsAccessKeyId, "aws-access-key-id", "", "AWS Access Key ID for S3/MinIO")
|
||||
cmd.Flags().StringVar(&awsSecretAccessKey, "aws-secret-access-key", "", "AWS Secret Access Key for S3/MinIO")
|
||||
cmd.Flags().StringVar(&awsEndpointUrl, "aws-endpoint-url", "", "AWS Endpoint URL (for MinIO or custom S3)")
|
||||
cmd.Flags().StringVar(&awsRegion, "aws-region", "", "AWS Region")
|
||||
cmd.Flags().StringVar(&aaKbsParams, "aa-kbs-params", "", "Attestation Agent KBS Parameters (e.g. protocol=http,type=kbs,url=http://... or just type=sample)")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMServerURL, serverURL, "", "CVM server URL")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&c.AgentVM.CVMCaURL, caUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&c.AgentVM.LogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&c.AgentVM.Ttl, ttlFlag, 0, "TTL for the VM")
|
||||
cmd.Flags().StringVar(&c.AWS.AccessKeyID, "aws-access-key-id", "", "AWS Access Key ID for S3/MinIO")
|
||||
cmd.Flags().StringVar(&c.AWS.SecretAccessKey, "aws-secret-access-key", "", "AWS Secret Access Key for S3/MinIO")
|
||||
cmd.Flags().StringVar(&c.AWS.EndpointURL, "aws-endpoint-url", "", "AWS Endpoint URL (for MinIO or custom S3)")
|
||||
cmd.Flags().StringVar(&c.AWS.Region, "aws-region", "", "AWS Region")
|
||||
cmd.Flags().StringVar(&c.Attestation.KbsParams, "aa-kbs-params", "", "Attestation Agent KBS Parameters (e.g. protocol=http,type=kbs,url=http://... or just type=sample)")
|
||||
if err := cmd.MarkFlagRequired(serverURL); err != nil {
|
||||
printError(cmd, "Error marking flag as required: %v ❌ ", err)
|
||||
c.printError(cmd, "Error marking flag as required: %v ❌ ", err)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -114,12 +98,12 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if c.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
if c.managerClient == nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
c.printError(cmd, "Failed to connect to manager: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -129,7 +113,7 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
|
||||
_, err := c.managerClient.RemoveVm(cmd.Context(), &manager.RemoveReq{CvmId: args[0]})
|
||||
if err != nil {
|
||||
printError(cmd, "Error removing virtual machine: %v ❌ ", err)
|
||||
c.printError(cmd, "Error removing virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -146,18 +130,18 @@ func fileReader(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func loadCerts() (*manager.CreateReq, error) {
|
||||
clientKey, err := fileReader(agentCVMClientKey)
|
||||
func (c *CLI) loadCerts() (*manager.CreateReq, error) {
|
||||
clientKey, err := fileReader(c.AgentVM.CVMClientKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientCrt, err := fileReader(agentCVMClientCrt)
|
||||
clientCrt, err := fileReader(c.AgentVM.CVMClientCrt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverCA, err := fileReader(agentCVMServerCA)
|
||||
serverCA, err := fileReader(c.AgentVM.CVMServerCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+25
-37
@@ -392,7 +392,7 @@ func TestLoadCerts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFiles func(string) error
|
||||
setupGlobal func(string)
|
||||
setupCLI func(string, *CLI)
|
||||
expectError bool
|
||||
validate func(*testing.T, *manager.CreateReq)
|
||||
}{
|
||||
@@ -411,10 +411,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "server.ca")
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
c.AgentVM.CVMServerCA = filepath.Join(tmpDir, "server.ca")
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
@@ -428,10 +428,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = ""
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = ""
|
||||
c.AgentVM.CVMClientCrt = ""
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
@@ -445,10 +445,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create client key file
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
|
||||
c.AgentVM.CVMClientCrt = ""
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
@@ -458,10 +458,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
// Create client key but not cert
|
||||
return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644)
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
|
||||
agentCVMServerCA = ""
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
|
||||
c.AgentVM.CVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
@@ -479,10 +479,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
|
||||
setupCLI: func(tmpDir string, c *CLI) {
|
||||
c.AgentVM.CVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
c.AgentVM.CVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
c.AgentVM.CVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
@@ -497,22 +497,10 @@ func TestLoadCerts(t *testing.T) {
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store original global variables
|
||||
origClientKey := agentCVMClientKey
|
||||
origClientCrt := agentCVMClientCrt
|
||||
origServerCA := agentCVMServerCA
|
||||
c := &CLI{}
|
||||
tt.setupCLI(tmpDir, c)
|
||||
|
||||
// Setup global variables for test
|
||||
tt.setupGlobal(tmpDir)
|
||||
|
||||
// Restore original values after test
|
||||
defer func() {
|
||||
agentCVMClientKey = origClientKey
|
||||
agentCVMClientCrt = origClientCrt
|
||||
agentCVMServerCA = origServerCA
|
||||
}()
|
||||
|
||||
result, err := loadCerts()
|
||||
result, err := c.loadCerts()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
@@ -592,7 +580,7 @@ func TestTTLHandling(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedTTL, ttl)
|
||||
assert.Equal(t, tt.expectedTTL, mockCLI.AgentVM.Ttl)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
+6
-6
@@ -24,7 +24,7 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[0])
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error reading private key file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var outputPath string
|
||||
if outputDir != "" {
|
||||
if err := os.MkdirAll(outputDir, 0o755); err != nil {
|
||||
printError(cmd, "Error creating output directory: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error creating output directory: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
outputPath = filepath.Join(outputDir, filename)
|
||||
@@ -56,19 +56,19 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
resultFile, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating result file: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error creating result file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer resultFile.Close()
|
||||
|
||||
if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil {
|
||||
printError(cmd, "Error retrieving computation result: %v ❌ ", err)
|
||||
cli.printError(cmd, "Error retrieving computation result: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+28
-1
@@ -4,6 +4,7 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
@@ -15,7 +16,26 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
)
|
||||
|
||||
var Verbose bool
|
||||
type AgentVMConfig struct {
|
||||
CVMServerURL string
|
||||
CVMServerCA string
|
||||
CVMClientKey string
|
||||
CVMClientCrt string
|
||||
CVMCaURL string
|
||||
LogLevel string
|
||||
Ttl time.Duration
|
||||
}
|
||||
|
||||
type AWSConfig struct {
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
EndpointURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type AttestationConfig struct {
|
||||
KbsParams string
|
||||
}
|
||||
|
||||
type CLI struct {
|
||||
agentSDK sdk.SDK
|
||||
@@ -25,6 +45,13 @@ type CLI struct {
|
||||
managerClient manager.ManagerServiceClient
|
||||
connectErr error
|
||||
measurement cmdconfig.MeasurementProvider
|
||||
Verbose bool
|
||||
IsManifest bool
|
||||
ToBase64 bool
|
||||
KeyType string
|
||||
AgentVM AgentVMConfig
|
||||
AWS AWSConfig
|
||||
Attestation AttestationConfig
|
||||
}
|
||||
|
||||
func New(agentConfig clients.AttestedClientConfig, managerConfig clients.StandardClientConfig, measurement cmdconfig.MeasurementProvider) *CLI {
|
||||
|
||||
@@ -45,17 +45,6 @@ type config struct {
|
||||
EATIssuer string `env:"ATTESTATION_EAT_ISSUER" envDefault:"cocos-attestation-service"`
|
||||
UseCCAttestationAgent bool `env:"USE_CC_ATTESTATION_AGENT" envDefault:"false"`
|
||||
CCAgentAddress string `env:"CC_AGENT_ADDRESS" envDefault:"127.0.0.1:50002"`
|
||||
|
||||
// Future KBS Integration Configuration
|
||||
// When KBS support is added, these fields will enable:
|
||||
// - Remote attestation verification via KBS
|
||||
// - Encrypted algorithm/dataset retrieval
|
||||
// - Per-computation secret provisioning
|
||||
//
|
||||
// Example future fields:
|
||||
// KBSEndpoint string `env:"KBS_ENDPOINT" envDefault:""` // Optional KBS URL
|
||||
// KBSEnabled bool `env:"KBS_ENABLED" envDefault:"false"`
|
||||
// KBSTimeout int `env:"KBS_TIMEOUT_SECONDS" envDefault:"30"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
+2
-2
@@ -122,7 +122,7 @@ func main() {
|
||||
defer cliSVC.Close()
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().BoolVarP(&cli.Verbose, "verbose", "v", false, "Enable verbose output")
|
||||
rootCmd.PersistentFlags().BoolVarP(&cliSVC.Verbose, "verbose", "v", false, "Enable verbose output")
|
||||
|
||||
keysCmd := cliSVC.NewKeysCmd()
|
||||
attestationCmd := cliSVC.NewAttestationCmd()
|
||||
@@ -151,7 +151,7 @@ func main() {
|
||||
|
||||
// Flags
|
||||
keysCmd.PersistentFlags().StringVarP(
|
||||
&cli.KeyType,
|
||||
&cliSVC.KeyType,
|
||||
"key-type",
|
||||
"k",
|
||||
"rsa",
|
||||
|
||||
+3
-1
@@ -56,7 +56,9 @@ func (s *SkopeoClient) PullAndDecrypt(ctx context.Context, source ResourceSource
|
||||
|
||||
args := []string{"copy"}
|
||||
|
||||
// Add decryption key if image is encrypted
|
||||
// Add decryption key if image is encrypted.
|
||||
// The KBS URL is configured at the CoCo Keyprovider service level
|
||||
// (via kernel cmdline agent.aa_kbc_params), not via ocicrypt's --decryption-key flag.
|
||||
if source.Encrypted {
|
||||
args = append(args, "--decryption-key", DecryptionKeyProvider)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ type ResourceSource struct {
|
||||
// KBSResourcePath is the KBS resource path for the decryption key
|
||||
// (e.g., "default/key/algo-key")
|
||||
KBSResourcePath string `json:"kbs_resource_path,omitempty"`
|
||||
|
||||
// KBSURL is the KBS endpoint URL for this specific resource
|
||||
KBSURL string `json:"kbs_url,omitempty"`
|
||||
}
|
||||
|
||||
// ImageManifest represents basic OCI image manifest information.
|
||||
|
||||
+24
-15
@@ -43,10 +43,11 @@ var (
|
||||
pubKeyFile string
|
||||
clientCAFile string
|
||||
// Remote resource configuration.
|
||||
kbsURL string
|
||||
algoKBSURL string
|
||||
algoSourceURL string
|
||||
algoSourceType string
|
||||
algoKBSResourcePath string
|
||||
datasetKBSURLs string
|
||||
datasetSourceURLs string
|
||||
datasetSourceType string
|
||||
datasetKBSPaths string
|
||||
@@ -78,12 +79,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
// Check if using remote datasets
|
||||
var datasetURLs []string
|
||||
var datasetKBSPathsList []string
|
||||
var datasetKBSURLsList []string
|
||||
if datasetSourceURLs != "" {
|
||||
datasetURLs = strings.Split(datasetSourceURLs, ",")
|
||||
}
|
||||
if datasetKBSPaths != "" {
|
||||
datasetKBSPathsList = strings.Split(datasetKBSPaths, ",")
|
||||
}
|
||||
if datasetKBSURLs != "" {
|
||||
datasetKBSURLsList = strings.Split(datasetKBSURLs, ",")
|
||||
}
|
||||
|
||||
var datasetDecompressList []bool
|
||||
if datasetDecompress != "" {
|
||||
@@ -124,7 +129,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
if srcType == "" {
|
||||
srcType = "oci-image"
|
||||
}
|
||||
datasets = append(datasets, &cvms.Dataset{
|
||||
d := &cvms.Dataset{
|
||||
Hash: dataHashBytes,
|
||||
UserKey: pubPem.Bytes,
|
||||
Filename: fmt.Sprintf("dataset_%d.csv", i),
|
||||
@@ -134,7 +139,14 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
KbsResourcePath: datasetKBSPathsList[i],
|
||||
Encrypted: datasetKBSPathsList[i] != "",
|
||||
},
|
||||
})
|
||||
}
|
||||
if len(datasetKBSURLsList) > i && datasetKBSURLsList[i] != "" {
|
||||
d.Kbs = &cvms.KBSConfig{
|
||||
Url: datasetKBSURLsList[i],
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
datasets = append(datasets, d)
|
||||
if len(datasetDecompressList) > i {
|
||||
datasets[len(datasets)-1].Decompress = datasetDecompressList[i]
|
||||
}
|
||||
@@ -202,6 +214,12 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
Encrypted: algoKBSResourcePath != "",
|
||||
},
|
||||
}
|
||||
if algoKBSURL != "" {
|
||||
algorithm.Kbs = &cvms.KBSConfig{
|
||||
Url: algoKBSURL,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Direct upload mode - use local file
|
||||
fileHash, err := internal.ChecksumHex(algoPath)
|
||||
@@ -225,15 +243,6 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
}
|
||||
}
|
||||
|
||||
// Build KBS config
|
||||
var kbsConfig *cvms.KBSConfig
|
||||
if kbsURL != "" {
|
||||
kbsConfig = &cvms.KBSConfig{
|
||||
Url: kbsURL,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debug("sending computation run request")
|
||||
if err := sendMessage(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReq{
|
||||
@@ -249,7 +258,6 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
AttestedTls: attestedTLS,
|
||||
ClientCaFile: clientCAFile,
|
||||
},
|
||||
Kbs: kbsConfig,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
@@ -271,14 +279,15 @@ func main() {
|
||||
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas (for direct upload mode)")
|
||||
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
|
||||
// Remote resource flags
|
||||
flagSet.StringVar(&kbsURL, "kbs-url", "", "KBS endpoint URL (e.g., 'http://localhost:8080')")
|
||||
flagSet.StringVar(&algoKBSURL, "algo-kbs-url", "", "Algorithm-specific KBS endpoint URL")
|
||||
flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (docker://..., s3://..., https://..., etc.)")
|
||||
flagSet.StringVar(&algoSourceType, "algo-source-type", "", "Algorithm source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
|
||||
flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')")
|
||||
flagSet.StringVar(&datasetKBSURLs, "dataset-kbs-urls", "", "Dataset-specific KBS endpoint URLs, comma-separated")
|
||||
flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated")
|
||||
flagSet.StringVar(&datasetSourceType, "dataset-source-type", "", "Dataset source type (oci-image, s3, gcs, https, http). Auto-detected from URL if empty.")
|
||||
flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated")
|
||||
flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type (e.g. binary, python)")
|
||||
flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type")
|
||||
flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated")
|
||||
flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)")
|
||||
flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type (deprecated, use --dataset-source-type)")
|
||||
|
||||
Reference in New Issue
Block a user