// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package agent

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"os/exec"
	"path/filepath"
	"slices"
	"strings"
	sync "sync"
	"time"

	"github.com/absmach/supermq/pkg/errors"
	"github.com/ultravioletrs/cocos/agent/algorithm"
	"github.com/ultravioletrs/cocos/agent/events"
	runnerpb "github.com/ultravioletrs/cocos/agent/runner"
	"github.com/ultravioletrs/cocos/agent/statemachine"
	"github.com/ultravioletrs/cocos/internal"
	"github.com/ultravioletrs/cocos/pkg/attestation"
	"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
	attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
	runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
	"github.com/ultravioletrs/cocos/pkg/oci"
	"golang.org/x/crypto/sha3"
)

var _ Service = (*agentService)(nil)

// AgentState enumerates the states of the agent's computation state machine.
// The declaration order matters: String() is generated by stringer.
//
//go:generate stringer -type=AgentState
type AgentState int

const (
	// Idle is the initial state; the Start event moves it to ReceivingManifest.
	Idle AgentState = iota
	// ReceivingManifest is the only state in which InitComputation is accepted.
	ReceivingManifest
	// ReceivingAlgorithm waits for the algorithm (auto-downloaded or uploaded via Algo).
	ReceivingAlgorithm
	// ReceivingData waits for the declared datasets (auto-downloaded or uploaded via Data).
	ReceivingData
	// Running executes the computation through the runner.
	Running
	// ConsumingResults waits for a declared consumer to fetch the results.
	ConsumingResults
	// Complete is reached once results have been consumed.
	Complete
	// Failed is reached when any step raises RunFailed.
	Failed
)

// AgentEvent enumerates the events that drive AgentState transitions.
// The declaration order matters: String() is generated by stringer.
//
//go:generate stringer -type=AgentEvent
type AgentEvent int

const (
	// Start moves Idle to ReceivingManifest.
	Start AgentEvent = iota
	// ManifestReceived is sent by InitComputation.
	ManifestReceived
	// AlgorithmReceived is sent once the algorithm has been stored.
	AlgorithmReceived
	// DataReceived is sent once every declared dataset has been stored.
	DataReceived
	// RunComplete is sent when the computation run finished without error.
	RunComplete
	// ResultsConsumed is sent on the first successful Result call.
	ResultsConsumed
	// RunFailed is sent when any step of the computation fails.
	RunFailed
)

// Status is the status string attached to events published for a computation.
// The declaration order matters: String() is generated by stringer.
//
//go:generate stringer -type=Status
type Status uint8

const (
	IdleState Status = iota
	InProgress
	Ready
	Completed
	Terminated
	Warning
	Starting
)

const (
	// algoFilePermission is applied to the algorithm file written to the working directory.
	algoFilePermission = 0o700
)

var (
	// ImaMeasurementsFilePath is the file read by IMAMeasurements.
	ImaMeasurementsFilePath = "/sys/kernel/security/integrity/ima/ascii_runtime_measurements"
	// ImaPcrIndex is the TPM PCR whose SHA-1 value IMAMeasurements returns.
	ImaPcrIndex = 10
)

var (
	// ErrMalformedEntity indicates malformed entity specification (e.g.
	// invalid username or password).
	ErrMalformedEntity = errors.New("malformed entity specification")

	// ErrUnauthorizedAccess indicates missing or invalid credentials provided
	// when accessing a protected resource.
	ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")

	// ErrUndeclaredDataset indicates the dataset was not declared in the computation manifest.
	ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")

	// ErrAllManifestItemsReceived indicates no new computation manifest items are expected.
	ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")

	// ErrUndeclaredConsumer indicates the consumer requesting results is not declared in the computation manifest.
	ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")

	// ErrResultsNotReady indicates the computation results are not ready.
	ErrResultsNotReady = errors.New("computation results are not yet ready")

	// ErrStateNotReady indicates the agent received a request in the wrong state.
	ErrStateNotReady = errors.New("agent not expecting this operation in the current state")

	// ErrHashMismatch indicates the provided algorithm/dataset does not match the hash in the manifest.
	ErrHashMismatch = errors.New("malformed data, hash does not match manifest")

	// ErrFileNameMismatch indicates the provided dataset filename does not match the filename in the manifest.
	ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")

	// ErrAllResultsConsumed indicates all results have been consumed.
	ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")

	// ErrAttestationFailed indicates that fetching the attestation quote failed.
	ErrAttestationFailed = errors.New("failed to get raw quote")

	// ErrAttestationVTpmFailed indicates that fetching the vTPM quote failed.
	ErrAttestationVTpmFailed = errors.New("failed to get vTPM quote")

	// ErrFetchAzureToken indicates that fetching the Azure token failed.
	ErrFetchAzureToken = errors.New("failed to get azure token")

	// ErrAttestationType indicates that the requested attestation type does not exist or is not supported.
	ErrAttestationType = errors.New("attestation type does not exist or is not supported")
)

// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface { InitComputation(ctx context.Context, cmp Computation) error StopComputation(ctx context.Context) error Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error Result(ctx context.Context) ([]byte, error) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) State() string } type OCIClient interface { PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error } type agentService struct { mu sync.Mutex computation Computation // Holds the current computation request details. runnerClient runner_client.Client algoType string algoArgs []string algoRequirements []byte algoReceived bool result []byte // Stores the result of the computation. sm statemachine.StateMachine // Manages the state transitions of the agent service. runError error // Stores any error encountered during the computation run. eventSvc events.Service // Service for publishing events related to computation. attestationClient attestation_client.Client // Client for attestation service. logger *slog.Logger // Logger for the agent service. resultsConsumed bool // Indicates if the results have been consumed. cancel context.CancelFunc // Cancels the computation context. vmpl int // VMPL at which the Agent is running. ociClient OCIClient } var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. 
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmlp int) Service { sm := statemachine.NewStateMachine(Idle) ctx, cancel := context.WithCancel(ctx) svc := &agentService{ sm: sm, eventSvc: eventSvc, attestationClient: attestationClient, runnerClient: runnerClient, logger: logger, cancel: cancel, vmpl: vmlp, } workDir := filepath.Join(os.TempDir(), "cocos-oci") skopeoClient, err := oci.NewSkopeoClient(workDir) if err != nil { logger.Warn("failed to create Skopeo client", "error", err) } svc.ociClient = skopeoClient transitions := []statemachine.Transition{ {From: Idle, Event: Start, To: ReceivingManifest}, {From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm}, } transitions = append(transitions, []statemachine.Transition{ {From: ReceivingAlgorithm, Event: RunFailed, To: Failed}, {From: ReceivingData, Event: RunFailed, To: Failed}, {From: Running, Event: RunComplete, To: ConsumingResults}, {From: Running, Event: RunFailed, To: Failed}, {From: ConsumingResults, Event: ResultsConsumed, To: Complete}, }...) 
for _, t := range transitions { sm.AddTransition(t) } sm.SetAction(ReceivingAlgorithm, svc.downloadAlgorithmIfRemote) sm.SetAction(ReceivingData, svc.downloadDatasetsIfRemote) sm.SetAction(Running, svc.runComputation) sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String())) sm.SetAction(Complete, svc.publishEvent(Completed.String())) sm.SetAction(Failed, svc.publishEvent(Failed.String())) go func() { if err := sm.Start(ctx); err != nil { logger.Error(err.Error()) } }() time.Sleep(100 * time.Millisecond) sm.SendEvent(Start) time.Sleep(100 * time.Millisecond) return svc } func (as *agentService) State() string { return as.sm.GetState().String() } func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error { if as.sm.GetState() != ReceivingManifest { return ErrStateNotReady } defer as.sm.SendEvent(ManifestReceived) as.mu.Lock() defer as.mu.Unlock() as.computation = cmp // Debug: Log manifest details as.logger.Info("received computation manifest", "computation_id", cmp.ID, "kbs_enabled", cmp.KBS.Enabled, "kbs_url", cmp.KBS.URL, "algo_has_source", cmp.Algorithm.Source != nil, "dataset_count", len(cmp.Datasets)) if cmp.Algorithm.Source != nil { as.logger.Info("algorithm remote source configured", "url", cmp.Algorithm.Source.URL, "kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath) } else { as.logger.Info("algorithm remote source NOT configured - will wait for direct upload") } if cmp.KBS.Enabled { as.logger.Info("KBS is ENABLED", "url", cmp.KBS.URL) } else { as.logger.Info("KBS is NOT ENABLED") } for i, d := range cmp.Datasets { if d.Source != nil { as.logger.Info("dataset remote source configured", "index", i, "filename", d.Filename, "url", d.Source.URL, "kbs_resource_path", d.Source.KBSResourcePath) } } transitions := []statemachine.Transition{} if len(cmp.Datasets) == 0 { transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running}) } else { transitions = 
append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData}) transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running}) } for _, t := range transitions { as.sm.AddTransition(t) } return nil } func (as *agentService) StopComputation(ctx context.Context) error { as.mu.Lock() defer as.mu.Unlock() as.eventSvc.SendEvent(as.computation.ID, "Stopped", "Stopped", json.RawMessage{}) as.cancel() if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil { as.logger.Warn("failed to stop runner", "error", err) // proceed to cleanup } if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { return fmt.Errorf("error removing datasets directory: %v", err) } if err := os.RemoveAll(algorithm.ResultsDir); err != nil { return fmt.Errorf("error removing results directory: %v", err) } as.sm.Reset(Idle) as.computation = Computation{} as.algoReceived = false as.algoType = "" as.algoArgs = nil as.algoRequirements = nil as.result = nil as.runError = nil as.resultsConsumed = false ctx, cancel := context.WithCancel(ctx) as.cancel = cancel go func() { if err := as.sm.Start(ctx); err != nil { as.logger.Error(err.Error()) } }() time.Sleep(100 * time.Millisecond) as.sm.SendEvent(Start) time.Sleep(100 * time.Millisecond) return nil } // downloadAlgorithmIfRemote automatically downloads the algorithm if it has a remote source. // This is called as an action when entering the ReceivingAlgorithm state. 
func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) { as.publishEvent(InProgress.String())(state) as.mu.Lock() defer as.mu.Unlock() // Debug: Log decision point as.logger.Info("checking if algorithm should be downloaded automatically", "algo_has_source", as.computation.Algorithm.Source != nil, "kbs_enabled", as.computation.KBS.Enabled) // Check if algorithm should be downloaded from remote source if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled { as.logger.Info("downloading algorithm from remote source", "url", as.computation.Algorithm.Source.URL, "kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath) // Use background context for download operation ctx := context.Background() res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm") if err != nil { as.runError = fmt.Errorf("failed to download and decrypt algorithm: %w", err) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } // Verify hash hash := sha3.Sum256(res.Data) if hash != as.computation.Algorithm.Hash { as.runError = fmt.Errorf("algorithm hash mismatch: expected %x, got %x", as.computation.Algorithm.Hash, hash) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } // Write algorithm to file currentDir, err := os.Getwd() if err != nil { as.runError = fmt.Errorf("error getting current directory: %w", err) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } // If a source directory is available (e.g. 
from OCI extraction), copy all files if res.SourceDir != "" { as.logger.Info("copying extracted algorithm directory", "src", res.SourceDir, "dst", currentDir) // Simple recursive copy (using shell cp for simplicity and reliability on Linux) // Ensure we copy contents of SourceDir into currentDir // Simple recursive copy (using shell cp for simplicity and reliability on Linux) // Ensure we copy contents of SourceDir into currentDir cmd := exec.Command("cp", "-r", res.SourceDir+"/.", currentDir) if out, err := cmd.CombinedOutput(); err != nil { as.runError = fmt.Errorf("error copying algorithm directory: %v, output: %s", err, out) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } } f, err := os.Create(filepath.Join(currentDir, "algo")) if err != nil { as.runError = fmt.Errorf("error creating algorithm file: %w", err) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } if _, err := f.Write(res.Data); err != nil { as.runError = fmt.Errorf("error writing algorithm to file: %w", err) as.logger.Error(as.runError.Error()) f.Close() as.sm.SendEvent(RunFailed) return } if err := os.Chmod(f.Name(), algoFilePermission); err != nil { as.runError = fmt.Errorf("error changing file permissions: %w", err) as.logger.Error(as.runError.Error()) f.Close() as.sm.SendEvent(RunFailed) return } if err := f.Close(); err != nil { as.runError = fmt.Errorf("error closing file: %w", err) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } as.algoReceived = true as.algoRequirements = res.Requirements // Store requirements for installation // Create datasets directory if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { as.runError = fmt.Errorf("error creating datasets directory: %w", err) as.logger.Error(as.runError.Error()) as.sm.SendEvent(RunFailed) return } as.algoType = as.computation.Algorithm.AlgoType if as.algoType == "" { as.algoType = string(algorithm.AlgoTypeBin) } as.algoArgs = as.computation.Algorithm.AlgoArgs 
as.logger.Info("algorithm downloaded and saved successfully", "type", as.algoType, "has_requirements", len(res.Requirements) > 0) as.sm.SendEvent(AlgorithmReceived) } else { // If no remote source, do nothing - wait for direct upload via Algo() RPC call as.logger.Info("algorithm automatic download not triggered, waiting for direct upload", "reason", "no remote source or KBS not enabled") } } // downloadDatasetsIfRemote automatically downloads datasets that have remote sources. // This is called as an action when entering the ReceivingData state. func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) { as.publishEvent(InProgress.String())(state) as.mu.Lock() defer as.mu.Unlock() // Check if any datasets should be downloaded from remote sources hasRemoteDatasets := false for _, d := range as.computation.Datasets { if d.Source != nil && as.computation.KBS.Enabled { hasRemoteDatasets = true break } } if !hasRemoteDatasets { // No remote datasets, wait for direct uploads via Data() RPC calls return } // Download all remote datasets ctx := context.Background() for i := len(as.computation.Datasets) - 1; i >= 0; i-- { d := as.computation.Datasets[i] if d.Source != nil && as.computation.KBS.Enabled { as.logger.Info("downloading dataset from remote source", "filename", d.Filename) res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset") if err != nil { as.logger.Error("failed to download and decrypt dataset", "error", err, "filename", d.Filename) as.sm.SendEvent(RunFailed) return } // Verify hash hash := sha3.Sum256(res.Data) if hash != d.Hash { as.logger.Error("dataset hash mismatch", "filename", d.Filename) as.sm.SendEvent(RunFailed) return } // Write dataset to file f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, d.Filename)) if err != nil { as.logger.Error("error creating dataset file", "error", err, "filename", d.Filename) as.sm.SendEvent(RunFailed) return } if d.Decompress { if err := internal.UnzipFromMemory(res.Data, 
algorithm.DatasetsDir); err != nil { as.logger.Error("error decompressing dataset", "error", err, "filename", d.Filename) as.sm.SendEvent(RunFailed) return } } else { if _, err := f.Write(res.Data); err != nil { as.logger.Error("error writing dataset to file", "error", err, "filename", d.Filename) f.Close() as.sm.SendEvent(RunFailed) return } } if err := f.Close(); err != nil { as.logger.Error("error closing file", "error", err, "filename", d.Filename) as.sm.SendEvent(RunFailed) return } // Remove from pending datasets as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1) as.logger.Info("dataset downloaded and saved successfully", "filename", d.Filename) } } // If all datasets are downloaded, send DataReceived event if len(as.computation.Datasets) == 0 { as.logger.Info("all datasets downloaded successfully") as.sm.SendEvent(DataReceived) } // Otherwise, wait for remaining datasets to be uploaded via Data() RPC calls } // DecryptedResource holds the data and metadata of a downloaded and decrypted resource. type DecryptedResource struct { Data []byte Requirements []byte SourceDir string } // downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider. // For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically. 
func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { // Determine source type sourceType := source.Type if sourceType == "" { // Infer from URL if strings.HasPrefix(source.URL, "docker://") || strings.HasPrefix(source.URL, "oci:") { sourceType = "oci-image" } else { return nil, fmt.Errorf("unsupported source URL format: %s (use oci-image type)", source.URL) } } switch sourceType { case "oci-image": return as.downloadAndDecryptOCIImage(ctx, source, resourceType) default: return nil, fmt.Errorf("unsupported source type: %s", sourceType) } } // downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider. func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)", source.URL, source.Encrypted, source.KBSResourcePath)) // Create Skopeo client if as.ociClient == nil { return nil, fmt.Errorf("OCI client not initialized") } // Create OCI resource source ociSource := oci.ResourceSource{ Type: oci.ResourceTypeOCIImage, URI: source.URL, Encrypted: source.Encrypted, KBSResourcePath: source.KBSResourcePath, } // Pull and decrypt image // CoCo Keyprovider will automatically handle decryption via ocicrypt // Sanitize directory name to avoid Skopeo interpreting ':' as tag separator sanitizedName := strings.ReplaceAll(filepath.Base(source.URL), ":", "_") destDir := filepath.Join(os.TempDir(), "cocos-oci", "images", sanitizedName) if err := as.ociClient.PullAndDecrypt(ctx, ociSource, destDir); err != nil { return nil, fmt.Errorf("failed to pull and decrypt OCI image: %w", err) } as.logger.Info("OCI image downloaded and decrypted", "dest", destDir) // Extract algorithm file from OCI layers extractDir := filepath.Join(os.TempDir(), "cocos-oci", "extracted", sanitizedName) var 
algorithmPath string var err error if resourceType == "algorithm" { algorithmPath, err = oci.ExtractAlgorithm(ctx, as.logger, destDir, extractDir) if err != nil { return nil, fmt.Errorf("failed to extract algorithm from OCI image: %w", err) } as.logger.Info("algorithm extracted from OCI image", "path", algorithmPath) } else { // Assume dataset files, err := oci.ExtractDataset(destDir, extractDir) if err != nil || len(files) == 0 { return nil, fmt.Errorf("failed to extract dataset from OCI image: %w", err) } // For now, take the first file found. // nolint:godox // TODO: Handle multiple files / directory structure if needed. algorithmPath = files[0] as.logger.Info("dataset extracted from OCI image", "path", algorithmPath) } // Read algorithm file algorithmData, err := os.ReadFile(algorithmPath) if err != nil { return nil, fmt.Errorf("failed to read algorithm file: %w", err) } // Check for requirements.txt if algorithm var reqData []byte if resourceType == "algorithm" { reqPath := filepath.Join(filepath.Dir(algorithmPath), "requirements.txt") if data, err := os.ReadFile(reqPath); err == nil { reqData = data as.logger.Info("found requirements.txt", "size", len(data)) } } as.logger.Info("algorithm loaded", "size", len(algorithmData)) return &DecryptedResource{ Data: algorithmData, Requirements: reqData, SourceDir: filepath.Dir(algorithmPath), }, nil } func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { if as.sm.GetState() != ReceivingAlgorithm { return ErrStateNotReady } as.mu.Lock() defer as.mu.Unlock() if as.algoReceived { return ErrAllManifestItemsReceived } var algoData []byte // Check if algorithm should be downloaded from remote source if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled { as.logger.Info("downloading algorithm from remote source") res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm") if err != nil { return fmt.Errorf("failed to download and decrypt algorithm: %w", 
err) } algoData = res.Data as.algoRequirements = res.Requirements } else { // Use directly uploaded algorithm algoData = algo.Algorithm } hash := sha3.Sum256(algoData) if hash != as.computation.Algorithm.Hash { return ErrHashMismatch } currentDir, err := os.Getwd() if err != nil { return fmt.Errorf("error getting current directory: %v", err) } f, err := os.Create(filepath.Join(currentDir, "algo")) if err != nil { return fmt.Errorf("error creating algorithm file: %v", err) } if _, err := f.Write(algoData); err != nil { return fmt.Errorf("error writing algorithm to file: %v", err) } if err := os.Chmod(f.Name(), algoFilePermission); err != nil { return fmt.Errorf("error changing file permissions: %v", err) } if err := f.Close(); err != nil { return fmt.Errorf("error closing file: %v", err) } algoType := algorithm.AlgorithmTypeFromContext(ctx) if algoType == "" { algoType = string(algorithm.AlgoTypeBin) } args := algorithm.AlgorithmArgsFromContext(ctx) as.algoType = algoType as.algoArgs = args as.algoRequirements = algo.Requirements as.algoReceived = true if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { return fmt.Errorf("error creating datasets directory: %v", err) } if as.algoReceived { as.sm.SendEvent(AlgorithmReceived) } return nil } func (as *agentService) Data(ctx context.Context, dataset Dataset) error { if as.sm.GetState() != ReceivingData { return ErrStateNotReady } as.mu.Lock() defer as.mu.Unlock() if len(as.computation.Datasets) == 0 { return ErrAllManifestItemsReceived } var datasetData []byte var datasetFilename string // Check if any dataset should be downloaded from remote source matchedIndex := -1 for i, d := range as.computation.Datasets { if d.Source != nil && as.computation.KBS.Enabled { as.logger.Info("downloading dataset from remote source", "filename", d.Filename) downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset") if err != nil { return fmt.Errorf("failed to download and decrypt dataset: %w", err) } 
datasetData = downloadedData.Data datasetFilename = d.Filename matchedIndex = i break } } // If no remote dataset, use uploaded dataset if matchedIndex == -1 { datasetData = dataset.Dataset datasetFilename = dataset.Filename } hash := sha3.Sum256(datasetData) matched := false for i, d := range as.computation.Datasets { if hash == d.Hash { if d.Filename != "" && d.Filename != datasetFilename { return ErrFileNameMismatch } as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1) if DecompressFromContext(ctx) { if err := internal.UnzipFromMemory(datasetData, algorithm.DatasetsDir); err != nil { return fmt.Errorf("error decompressing dataset: %v", err) } } else { f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, datasetFilename)) if err != nil { return fmt.Errorf("error creating dataset file: %v", err) } if _, err := f.Write(datasetData); err != nil { return fmt.Errorf("error writing dataset to file: %v", err) } if err := f.Close(); err != nil { return fmt.Errorf("error closing file: %v", err) } } matched = true break } } if !matched { return ErrUndeclaredDataset } if len(as.computation.Datasets) == 0 { defer as.sm.SendEvent(DataReceived) } return nil } func (as *agentService) Result(ctx context.Context) ([]byte, error) { currentState := as.sm.GetState() if currentState != ConsumingResults && currentState != Complete && currentState != Failed { return []byte{}, ErrResultsNotReady } index, ok := IndexFromContext(ctx) if !ok { return []byte{}, ErrUndeclaredConsumer } as.mu.Lock() defer as.mu.Unlock() if index < 0 || index >= len(as.computation.ResultConsumers) { return []byte{}, ErrUndeclaredConsumer } if !as.resultsConsumed && currentState == ConsumingResults { as.resultsConsumed = true defer as.sm.SendEvent(ResultsConsumed) } return as.result, as.runError } func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { rawQuote, err := 
as.attestationClient.GetAttestation(ctx, reportData, nonce, attType) if err != nil { return []byte{}, errors.Wrap(ErrAttestationFailed, err) } return rawQuote, nil } func (as *agentService) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) { token, err := as.attestationClient.GetAzureToken(ctx, nonce) if err != nil { return []byte{}, err } return token, nil } func (as *agentService) runComputation(state statemachine.State) { as.publishEvent(Starting.String())(state) as.logger.Debug("computation run started") defer func() { if as.runError != nil { as.sm.SendEvent(RunFailed) } else { as.sm.SendEvent(RunComplete) } }() if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { as.runError = fmt.Errorf("error creating results directory: %s", err.Error()) as.logger.Warn(as.runError.Error()) as.publishEvent(Failed.String())(state) return } defer func() { if err := os.RemoveAll(algorithm.ResultsDir); err != nil { as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) } if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) } }() // Read algo file currentDir, _ := os.Getwd() algoFile := filepath.Join(currentDir, "algo") algoBytes, err := os.ReadFile(algoFile) if err != nil { as.runError = fmt.Errorf("failed to read algo file: %w", err) as.logger.Warn(as.runError.Error()) as.publishEvent(Failed.String())(state) return } as.publishEvent(InProgress.String())(state) // Call Runner resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{ ComputationId: as.computation.ID, AlgoType: as.algoType, Algorithm: algoBytes, Requirements: as.algoRequirements, Args: as.algoArgs, // Datasets implicit on shared FS }) if err != nil { as.runError = err as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) as.publishEvent(Failed.String())(state) return } if resp.Error != "" 
{ as.runError = errors.New(resp.Error) as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error)) as.publishEvent(Failed.String())(state) return } results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir) if err != nil { as.runError = err as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) as.publishEvent(Failed.String())(state) return } as.publishEvent(Completed.String())(state) as.result = results } func (as *agentService) publishEvent(status string) statemachine.Action { return func(state statemachine.State) { as.eventSvc.SendEvent(as.computation.ID, state.String(), status, json.RawMessage{}) } } func (as *agentService) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) { data, err := os.ReadFile(ImaMeasurementsFilePath) if err != nil { return nil, nil, fmt.Errorf("error reading Linux IMA measurements file: %s", err.Error()) } pcr10, err := vtpm.GetPCRSHA1Value(ImaPcrIndex) if err != nil { return nil, nil, fmt.Errorf("error reading TPM PCR #10: %s", err.Error()) } return data, pcr10, nil }