Files
cocos/agent/service.go
T
Sammy Kerata Oina 207bfd99af COCOS-525-487 - Refactor attestation and atls (#562)
* Refactor attestation handling to remove quoteprovider dependency

- Removed references to quoteprovider in various files, replacing them with vtpm where necessary.
- Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce.
- Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports.
- Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure.
- Deleted the now redundant quoteprovider/sev_test.go file.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Add veraison/go-cose dependency to go.mod

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-02-18 11:53:04 +01:00

531 lines
16 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"slices"
sync "sync"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events"
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
"golang.org/x/crypto/sha3"
)
// Compile-time check that *agentService satisfies Service.
var (
	_ Service = (*agentService)(nil)
)
// AgentState is a state of the agent's computation state machine. The machine
// is created in Idle by New and advances as the manifest, algorithm, datasets
// and result requests arrive.
//
//go:generate stringer -type=AgentState
type AgentState int

const (
	// Idle is the initial state; New sends Start to leave it.
	Idle AgentState = iota
	// ReceivingManifest waits for InitComputation.
	ReceivingManifest
	// ReceivingAlgorithm waits for Algo.
	ReceivingAlgorithm
	// ReceivingData waits for Data; it is only entered when the manifest
	// declares at least one dataset (see InitComputation).
	ReceivingData
	// Running executes the computation via the runComputation action.
	Running
	// ConsumingResults makes the computation result available through Result.
	ConsumingResults
	// Complete is reached once the results have been consumed.
	Complete
	// Failed is reached when the run fails (RunFailed event).
	Failed
)
// AgentEvent is an input to the agent's state machine that triggers a
// transition between AgentState values.
//
//go:generate stringer -type=AgentEvent
type AgentEvent int

const (
	// Start moves Idle to ReceivingManifest; sent by New and StopComputation.
	Start AgentEvent = iota
	// ManifestReceived is sent by InitComputation (ReceivingManifest -> ReceivingAlgorithm).
	ManifestReceived
	// AlgorithmReceived is sent by Algo; it leads to ReceivingData, or to
	// Running when the manifest declares no datasets.
	AlgorithmReceived
	// DataReceived is sent by Data once every declared dataset has been
	// stored (ReceivingData -> Running).
	DataReceived
	// RunComplete is sent when runComputation ends without a run error
	// (Running -> ConsumingResults).
	RunComplete
	// ResultsConsumed is sent by the first Result call made in
	// ConsumingResults (ConsumingResults -> Complete).
	ResultsConsumed
	// RunFailed is sent when runComputation ends with a run error (Running -> Failed).
	RunFailed
)
// Status is the progress status attached to the events published for a
// computation (see publishEvent).
//
//go:generate stringer -type=Status
type Status uint8

const (
	// IdleState is not published anywhere in this file.
	// NOTE(review): presumably used by other packages; confirm.
	IdleState Status = iota
	// InProgress is published on entering ReceivingAlgorithm/ReceivingData and
	// by runComputation just before the runner is invoked.
	InProgress
	// Ready is published on entering ConsumingResults.
	Ready
	// Completed is published when a run succeeds and on entering Complete.
	Completed
	// Terminated is not published anywhere in this file.
	Terminated
	// Warning is not published anywhere in this file.
	Warning
	// Starting is published when runComputation begins.
	Starting
)
// algoFilePermission is the mode applied to the uploaded algorithm file:
// read, write and execute for the owner only.
const algoFilePermission = 0o700
// ImaMeasurementsFilePath is the securityfs file exposing the Linux IMA
// runtime measurement list in ASCII form.
const ImaMeasurementsFilePath = "/sys/kernel/security/integrity/ima/ascii_runtime_measurements"

// ImaPcrIndex is the TPM PCR read alongside the IMA measurement list.
const ImaPcrIndex = 10
var (
	// ErrMalformedEntity indicates malformed entity specification (e.g.
	// invalid username or password).
	ErrMalformedEntity = errors.New("malformed entity specification")

	// ErrUnauthorizedAccess indicates missing or invalid credentials provided
	// when accessing a protected resource.
	ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")

	// ErrUndeclaredDataset indicates the dataset was not declared in the computation manifest.
	ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")

	// ErrAllManifestItemsReceived indicates no new computation manifest items are expected.
	ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")

	// ErrUndeclaredConsumer indicates the consumer requesting results is not declared in the computation manifest.
	ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")

	// ErrResultsNotReady indicates the computation results are not ready.
	ErrResultsNotReady = errors.New("computation results are not yet ready")

	// ErrStateNotReady indicates the agent received a request in the wrong state.
	ErrStateNotReady = errors.New("agent not expecting this operation in the current state")

	// ErrHashMismatch indicates the provided algorithm/dataset does not match the hash in the manifest.
	ErrHashMismatch = errors.New("malformed data, hash does not match manifest")

	// ErrFileNameMismatch indicates the provided dataset filename does not match the filename in the manifest.
	ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")

	// ErrAllResultsConsumed indicates all results have been consumed.
	ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")

	// ErrAttestationFailed indicates fetching the attestation report (raw quote) failed.
	ErrAttestationFailed = errors.New("failed to get raw quote")

	// ErrAttestationVTpmFailed indicates vTPM attestation failed.
	ErrAttestationVTpmFailed = errors.New("failed to get vTPM quote")

	// ErrFetchAzureToken indicates fetching the Azure attestation token failed.
	ErrFetchAzureToken = errors.New("failed to get azure token")

	// ErrAttestationType indicates that the requested attestation type does not exist or is not supported.
	ErrAttestationType = errors.New("attestation type does not exist or is not supported")
)
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
	// InitComputation accepts the computation manifest. It is only valid
	// while the agent is waiting for a manifest.
	InitComputation(ctx context.Context, cmp Computation) error
	// StopComputation aborts the current computation and resets the agent so
	// it can accept a new manifest.
	StopComputation(ctx context.Context) error
	// Algo uploads the algorithm declared in the manifest.
	Algo(ctx context.Context, algorithm Algorithm) error
	// Data uploads one of the datasets declared in the manifest.
	Data(ctx context.Context, dataset Dataset) error
	// Result returns the computation result to a declared result consumer.
	Result(ctx context.Context) ([]byte, error)
	// Attestation returns an attestation report of the requested platform
	// type for the given report data and nonce.
	Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
	// IMAMeasurements returns the Linux IMA measurement list and the value of
	// the TPM PCR associated with it.
	IMAMeasurements(ctx context.Context) ([]byte, []byte, error)
	// AzureAttestationToken returns an Azure attestation token for the given nonce.
	AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error)
	// State returns the name of the agent's current state.
	State() string
}
// agentService is the default Service implementation. It drives a computation
// through the AgentState state machine, delegates execution to the runner
// service and attestation to the attestation service.
type agentService struct {
	// mu serializes the public API methods that read or mutate the
	// per-computation fields below.
	// NOTE(review): runComputation writes result and runError without holding mu.
	mu sync.Mutex

	computation  Computation          // Holds the current computation request details.
	runnerClient runner_client.Client // Client used to run and stop the computation in the runner.

	algoType         string   // Algorithm type taken from the Algo request context (binary by default).
	algoArgs         []string // Arguments forwarded to the algorithm.
	algoRequirements []byte   // Requirements uploaded with the algorithm.
	algoReceived     bool     // Whether the algorithm has been accepted.

	result   []byte                    // Stores the result of the computation.
	sm       statemachine.StateMachine // Manages the state transitions of the agent service.
	runError error                     // Stores any error encountered during the computation run.

	eventSvc          events.Service            // Service for publishing events related to computation.
	attestationClient attestation_client.Client // Client for attestation service.
	logger            *slog.Logger              // Logger for the agent service.
	resultsConsumed   bool                      // Indicates if the results have been consumed.
	cancel            context.CancelFunc        // Cancels the state machine's context.
	vmpl              int                       // VMPL at which the Agent is running.
}
// New instantiates the agent service implementation.
//
// It builds the agent state machine with every transition that does not
// depend on the manifest (the dataset-dependent ones are registered by
// InitComputation), starts the machine in a background goroutine bound to a
// cancelable child of ctx, and sends Start so the agent is waiting for a
// manifest when New returns. vmpl is the VMPL at which the agent is running.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmpl int) Service {
	sm := statemachine.NewStateMachine(Idle)
	ctx, cancel := context.WithCancel(ctx)

	svc := &agentService{
		sm:                sm,
		eventSvc:          eventSvc,
		attestationClient: attestationClient,
		runnerClient:      runnerClient,
		logger:            logger,
		cancel:            cancel,
		vmpl:              vmpl,
	}

	// Manifest-independent transitions. The edges leaving ReceivingAlgorithm
	// and ReceivingData depend on whether datasets are declared, so
	// InitComputation adds them once the manifest is known.
	transitions := []statemachine.Transition{
		{From: Idle, Event: Start, To: ReceivingManifest},
		{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
		{From: Running, Event: RunComplete, To: ConsumingResults},
		{From: Running, Event: RunFailed, To: Failed},
		{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
	}
	for _, t := range transitions {
		sm.AddTransition(t)
	}

	sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
	sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
	sm.SetAction(Running, svc.runComputation)
	sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
	sm.SetAction(Complete, svc.publishEvent(Completed.String()))
	sm.SetAction(Failed, svc.publishEvent(Failed.String()))

	go func() {
		if err := sm.Start(ctx); err != nil {
			logger.Error(err.Error())
		}
	}()

	// NOTE(review): fixed sleeps give the state-machine goroutine time to start
	// and then to process Start; no readiness signal is used here.
	time.Sleep(100 * time.Millisecond)
	sm.SendEvent(Start)
	time.Sleep(100 * time.Millisecond)

	return svc
}
// State reports the name of the agent's current state-machine state.
func (as *agentService) State() string {
	current := as.sm.GetState()
	return current.String()
}
// InitComputation stores the computation manifest and registers the
// manifest-dependent state transitions. Without datasets, AlgorithmReceived
// leads directly to Running. Otherwise it leads to ReceivingData, and
// DataReceived then leads to Running. ManifestReceived is sent after the
// mutex has been released. Returns ErrStateNotReady unless the agent is in
// ReceivingManifest.
func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error {
	if current := as.sm.GetState(); current != ReceivingManifest {
		return ErrStateNotReady
	}
	// Deferred before locking, so it runs after the deferred Unlock below.
	defer as.sm.SendEvent(ManifestReceived)

	as.mu.Lock()
	defer as.mu.Unlock()

	as.computation = cmp

	var next []statemachine.Transition
	if len(cmp.Datasets) > 0 {
		next = []statemachine.Transition{
			{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData},
			{From: ReceivingData, Event: DataReceived, To: Running},
		}
	} else {
		next = []statemachine.Transition{
			{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running},
		}
	}
	for _, tr := range next {
		as.sm.AddTransition(tr)
	}

	return nil
}
// StopComputation aborts the current computation and resets the agent so it
// can accept a new manifest.
//
// It publishes a "Stopped" event, cancels the running state machine and asks
// the runner to stop (a failure there is only logged). It then removes the
// datasets and results directories, clears all per-computation state, and
// restarts the state machine, sending Start so it waits for a manifest again.
// An error is returned only if a directory cannot be removed.
func (as *agentService) StopComputation(ctx context.Context) error {
	as.mu.Lock()
	defer as.mu.Unlock()

	as.eventSvc.SendEvent(as.computation.ID, "Stopped", "Stopped", json.RawMessage{})
	as.cancel()

	if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil {
		as.logger.Warn("failed to stop runner", "error", err)
		// proceed to cleanup
	}

	if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
		return fmt.Errorf("error removing datasets directory: %w", err)
	}
	if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
		return fmt.Errorf("error removing results directory: %w", err)
	}

	as.sm.Reset(Idle)
	as.computation = Computation{}
	as.algoReceived = false
	as.algoType = ""
	as.algoArgs = nil
	as.algoRequirements = nil
	as.result = nil
	as.runError = nil
	as.resultsConsumed = false

	// The restarted state machine must outlive this call. ctx is the per-call
	// context; for an RPC handler it is canceled when the call returns.
	// Deriving the machine's context directly from it would stop the machine
	// right after StopComputation returned. Keep ctx's values, detach from its
	// cancellation, and rely on as.cancel to stop the machine.
	smCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
	as.cancel = cancel
	go func() {
		if err := as.sm.Start(smCtx); err != nil {
			as.logger.Error(err.Error())
		}
	}()

	// NOTE(review): fixed sleeps give the goroutine time to start and then to
	// process Start; no readiness signal is used here.
	time.Sleep(100 * time.Millisecond)
	as.sm.SendEvent(Start)
	time.Sleep(100 * time.Millisecond)

	return nil
}
// Algo accepts the algorithm declared in the computation manifest.
//
// The payload must match the manifest's SHA3-256 hash. It is written to the
// file "algo" in the working directory with algoFilePermission. The datasets
// directory is created, and the algorithm type (binary by default) and
// arguments are taken from ctx. Only after all of that succeeds is the
// algorithm recorded as received and AlgorithmReceived sent.
//
// Errors: ErrStateNotReady outside ReceivingAlgorithm,
// ErrAllManifestItemsReceived if an algorithm was already accepted,
// ErrHashMismatch, or a wrapped filesystem error.
func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
	if as.sm.GetState() != ReceivingAlgorithm {
		return ErrStateNotReady
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	if as.algoReceived {
		return ErrAllManifestItemsReceived
	}

	if hash := sha3.Sum256(algo.Algorithm); hash != as.computation.Algorithm.Hash {
		return ErrHashMismatch
	}

	currentDir, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("error getting current directory: %w", err)
	}

	f, err := os.Create(filepath.Join(currentDir, "algo"))
	if err != nil {
		return fmt.Errorf("error creating algorithm file: %w", err)
	}
	if _, err := f.Write(algo.Algorithm); err != nil {
		_ = f.Close() // best effort: the write error is the one worth reporting
		return fmt.Errorf("error writing algorithm to file: %w", err)
	}
	if err := os.Chmod(f.Name(), algoFilePermission); err != nil {
		_ = f.Close() // best effort: the chmod error is the one worth reporting
		return fmt.Errorf("error changing file permissions: %w", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("error closing file: %w", err)
	}

	// Create the datasets directory before recording the algorithm as
	// received. If this fails, the client can retry Algo instead of being
	// locked out with ErrAllManifestItemsReceived.
	if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
		return fmt.Errorf("error creating datasets directory: %w", err)
	}

	algoType := algorithm.AlgorithmTypeFromContext(ctx)
	if algoType == "" {
		algoType = string(algorithm.AlgoTypeBin)
	}

	as.algoType = algoType
	as.algoArgs = algorithm.AlgorithmArgsFromContext(ctx)
	as.algoRequirements = algo.Requirements
	as.algoReceived = true

	// Sent while mu is still held.
	as.sm.SendEvent(AlgorithmReceived)
	return nil
}
// Data accepts one dataset declared in the computation manifest.
//
// The dataset is matched against the remaining manifest entries by SHA3-256
// hash. If the matching entry pins a filename, the uploaded filename must be
// identical. If DecompressFromContext(ctx) is true, the payload is unzipped
// into algorithm.DatasetsDir. Otherwise it is written there as a single file
// named by the upload. The manifest entry is consumed only after the dataset
// has been stored, so a failed upload can be retried. Once every declared
// dataset is stored, DataReceived is sent to the state machine.
//
// Errors: ErrStateNotReady outside ReceivingData, ErrAllManifestItemsReceived
// when no datasets remain, ErrFileNameMismatch, ErrMalformedEntity for a
// filename that could escape the datasets directory, ErrUndeclaredDataset when
// no entry has this hash, or a wrapped filesystem error.
func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
	if as.sm.GetState() != ReceivingData {
		return ErrStateNotReady
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	if len(as.computation.Datasets) == 0 {
		return ErrAllManifestItemsReceived
	}

	hash := sha3.Sum256(dataset.Dataset)
	for i, d := range as.computation.Datasets {
		if hash != d.Hash {
			continue
		}
		if d.Filename != "" && d.Filename != dataset.Filename {
			return ErrFileNameMismatch
		}

		if DecompressFromContext(ctx) {
			if err := internal.UnzipFromMemory(dataset.Dataset, algorithm.DatasetsDir); err != nil {
				return fmt.Errorf("error decompressing dataset: %w", err)
			}
		} else {
			// The filename comes from the client and is not necessarily pinned
			// by the manifest. Reject anything that could resolve outside
			// DatasetsDir (absolute paths, "..", empty names).
			if !filepath.IsLocal(dataset.Filename) {
				return errors.Wrap(ErrMalformedEntity, fmt.Errorf("invalid dataset filename %q", dataset.Filename))
			}
			f, err := os.Create(filepath.Join(algorithm.DatasetsDir, dataset.Filename))
			if err != nil {
				return fmt.Errorf("error creating dataset file: %w", err)
			}
			if _, err := f.Write(dataset.Dataset); err != nil {
				_ = f.Close() // best effort: the write error is the one worth reporting
				return fmt.Errorf("error writing dataset to file: %w", err)
			}
			if err := f.Close(); err != nil {
				return fmt.Errorf("error closing file: %w", err)
			}
		}

		// Consume the manifest entry only now that the dataset is stored.
		as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1)
		if len(as.computation.Datasets) == 0 {
			// Sent while mu is still held.
			as.sm.SendEvent(DataReceived)
		}
		return nil
	}

	return ErrUndeclaredDataset
}
// Result returns the computation result, together with any run error, to a
// result consumer declared in the manifest.
//
// Results are available in ConsumingResults, Complete and Failed. The
// consumer index comes from ctx and must be within the manifest's
// ResultConsumers, otherwise ErrUndeclaredConsumer is returned. The first
// call made in ConsumingResults marks the results as consumed and sends
// ResultsConsumed while mu is still held.
func (as *agentService) Result(ctx context.Context) ([]byte, error) {
	state := as.sm.GetState()
	resultsAvailable := state == ConsumingResults || state == Complete || state == Failed
	if !resultsAvailable {
		return []byte{}, ErrResultsNotReady
	}

	consumer, found := IndexFromContext(ctx)
	if !found {
		return []byte{}, ErrUndeclaredConsumer
	}

	as.mu.Lock()
	defer as.mu.Unlock()

	if consumer < 0 || consumer >= len(as.computation.ResultConsumers) {
		return []byte{}, ErrUndeclaredConsumer
	}

	if state == ConsumingResults && !as.resultsConsumed {
		as.resultsConsumed = true
		// Deferred after Unlock, so it runs first (LIFO), with mu held.
		defer as.sm.SendEvent(ResultsConsumed)
	}

	return as.result, as.runError
}
// Attestation fetches an attestation report of type attType from the
// attestation service, passing reportData and nonce through. Client errors
// are wrapped in ErrAttestationFailed.
func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
	quote, fetchErr := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType)
	if fetchErr != nil {
		return []byte{}, errors.Wrap(ErrAttestationFailed, fetchErr)
	}
	return quote, nil
}
// AzureAttestationToken fetches an Azure attestation token for nonce from the
// attestation service. Client errors are wrapped in ErrFetchAzureToken, so
// callers can recognise the failure the same way Attestation failures are
// recognised via ErrAttestationFailed.
func (as *agentService) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) {
	token, err := as.attestationClient.GetAzureToken(ctx, nonce)
	if err != nil {
		return []byte{}, errors.Wrap(ErrFetchAzureToken, err)
	}
	return token, nil
}
// runComputation is the state-machine action for the Running state.
//
// It creates algorithm.ResultsDir and reads the uploaded "algo" file from the
// working directory. It asks the runner to execute it with the stored type,
// requirements and arguments (datasets are expected in algorithm.DatasetsDir),
// then zips ResultsDir into memory as the computation result. Progress and
// failures are published as events. Any failure is stored in as.runError. On
// return the results and datasets directories are removed, then RunFailed or
// RunComplete is sent depending on whether runError is set.
func (as *agentService) runComputation(state statemachine.State) {
	as.publishEvent(Starting.String())(state)
	as.logger.Debug("computation run started")

	// Registered first so it runs last, after the directory cleanup below.
	defer func() {
		if as.runError != nil {
			as.sm.SendEvent(RunFailed)
		} else {
			as.sm.SendEvent(RunComplete)
		}
	}()

	// fail records runErr as the outcome of the run, logs logMsg and
	// publishes a Failed event for the current state.
	fail := func(runErr error, logMsg string) {
		as.runError = runErr
		as.logger.Warn(logMsg)
		as.publishEvent(Failed.String())(state)
	}

	if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
		runErr := fmt.Errorf("error creating results directory: %w", err)
		fail(runErr, runErr.Error())
		return
	}
	defer func() {
		if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
		}
		if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
			as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
		}
	}()

	// Getwd's error is deliberately ignored. With an empty dir, Join yields
	// the relative path "algo", which still resolves against the working
	// directory where Algo wrote the file.
	currentDir, _ := os.Getwd()
	algoBytes, err := os.ReadFile(filepath.Join(currentDir, "algo"))
	if err != nil {
		runErr := fmt.Errorf("failed to read algo file: %w", err)
		fail(runErr, runErr.Error())
		return
	}

	as.publishEvent(InProgress.String())(state)

	// NOTE(review): Run uses context.Background(), so cancelling the state
	// machine does not cancel an in-flight run. StopComputation stops the
	// runner separately via runnerClient.Stop.
	resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{
		ComputationId: as.computation.ID,
		AlgoType:      as.algoType,
		Algorithm:     algoBytes,
		Requirements:  as.algoRequirements,
		Args:          as.algoArgs,
		// Datasets implicit on shared FS
	})
	if err != nil {
		fail(err, fmt.Sprintf("failed to run computation: %s", err.Error()))
		return
	}
	if resp.Error != "" {
		fail(errors.New(resp.Error), fmt.Sprintf("failed to run computation: %s", resp.Error))
		return
	}

	results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
	if err != nil {
		fail(err, fmt.Sprintf("failed to zip results: %s", err.Error()))
		return
	}

	as.result = results
	as.publishEvent(Completed.String())(state)
}
// publishEvent builds a state-machine action that publishes an event for the
// current computation. The event carries the name of the state the machine
// passes to the action, the given status, and an empty JSON payload.
func (as *agentService) publishEvent(status string) statemachine.Action {
	var action statemachine.Action = func(st statemachine.State) {
		as.eventSvc.SendEvent(as.computation.ID, st.String(), status, json.RawMessage{})
	}
	return action
}
// IMAMeasurements returns the Linux IMA measurement list read from
// ImaMeasurementsFilePath, together with the SHA-1 bank value of TPM PCR
// ImaPcrIndex as reported by vtpm.GetPCRSHA1Value. Errors are wrapped with
// %w, so callers can inspect the underlying cause (e.g. os.ErrNotExist).
func (as *agentService) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) {
	data, err := os.ReadFile(ImaMeasurementsFilePath)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading Linux IMA measurements file: %w", err)
	}

	pcr, err := vtpm.GetPCRSHA1Value(ImaPcrIndex)
	if err != nil {
		return nil, nil, fmt.Errorf("error reading TPM PCR #%d: %w", ImaPcrIndex, err)
	}

	return data, pcr, nil
}