Files
cocos/manager/service.go
T
Sammy Kerata Oina c1cbcec851
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
COCOS-577 - Introduce Go-based CoRIM generation and deprecate Rust attestation policy scripts. (#578)
* feat: Introduce Go-based CoRIM generation and deprecate Rust attestation policy scripts.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update dependencies and refactor attestation policy handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Migrate attestation verification to use CoRIM and remove deprecated policy handling and EAT verification tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Removed the `tdx` and `sev-snp` attestation policy scripts and their build configurations, along with related build and installation steps from the main Makefile.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Remove Rust CI workflow and Cargo Dependabot configuration, and enhance Go test setup for attestation policy paths.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Use WriteString instead of Write([]byte) for writing policy file content in test.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Refactor `ca-bundle` command to fetch bundles by product string using a configurable HTTP getter with improved error handling, and simplify `attestation_policy` command usage.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: ignore return value of cmd.Help()

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Implement CoRIM generation for Azure and GCP attestation policies and add a CLI command to download and verify GCP OVMF files.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Upgrade Python virtual environment setup to include setuptools and wheel, append computation ID to Docker container names, and improve test robustness with error assertions and conditional skips for runtime tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Enhance attestation verification tests, including CoRIM integration and specific platform types like Azure SNP, vTPM, TDX, and IGVM.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add comprehensive test cases for `VerifyWithCoRIM` including success and measurement mismatch, and refine reference value validation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add Azure and TDX attestation verification tests and abstract external service dependencies for improved testability.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add new test cases for Azure measurement extraction, EAT platform types, IGVM measurement stopping, vTPM CoRIM verification, and GCP OVMF download CLI.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: enhance CLI CoRIM generation and ATLS certificate verification tests, and refactor the Azure MAA client to use an interface.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-03-19 17:01:24 +01:00

487 lines
13 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"context"
"crypto"
"encoding/base64"
"fmt"
"log/slog"
"net"
"os"
"regexp"
"strconv"
"sync"
"syscall"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/google/uuid"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/pkg/attestation/corimgen"
"github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/crypto/sha3"
)
// Manager-wide constants: the persistence location, the names of the variables
// written to the CVM environment file by tmpEnvironment, and the certificate
// paths advertised to the agent through those variables.
const (
// persistenceDir is the directory handed to qemu.NewFilePersistence in New.
persistenceDir = "/tmp/cocos"
// Variable names used as keys in the CVM environment file.
agentLogLevelKey = "AGENT_LOG_LEVEL"
agentCvmGrpcUrlKey = "AGENT_CVM_GRPC_URL"
agentCvmClientCertKey = "AGENT_CVM_GRPC_CLIENT_CERT"
agentCvmClientKey = "AGENT_CVM_GRPC_CLIENT_KEY"
agentCvmServerCaCertKey = "AGENT_CVM_GRPC_SERVER_CA_CERTS"
agentCvmId = "AGENT_CVM_ID"
agentCaToken = "AGENT_CERTS_TOKEN"
agentCvmCaUrl = "AGENT_CVM_CA_URL"
// AWS-related variable names; tmpEnvironment only writes them when the
// corresponding request field is non-empty.
awsAccessKeyIdKey = "AWS_ACCESS_KEY_ID"
awsSecretAccessKeyKey = "AWS_SECRET_ACCESS_KEY"
awsEndpointUrlKey = "AWS_ENDPOINT_URL"
awsRegionKey = "AWS_REGION"
// Values written for the certificate variables when the request carries the
// matching material. NOTE(review): presumably these are paths as seen from
// inside the CVM, where the certs directory is mounted - confirm.
defClientCertPath = "/etc/certs/cert.pem"
defClientKeyPath = "/etc/certs/key.pem"
defServerCaCertPath = "/etc/certs/ca.pem"
// cvmEnvironmentFile is the name of the file created inside the temporary
// environment directory.
cvmEnvironmentFile = "environment"
)
// Sentinel errors exposed by the manager package.
var (
// ErrMalformedEntity indicates malformed entity specification (e.g.
// invalid username or password).
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrUnauthorizedAccess indicates missing or invalid credentials provided
// when accessing a protected resource.
ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
// ErrNotFound indicates a non-existent entity request. RemoveVM returns it
// when no VM is registered under the given ID.
ErrNotFound = errors.New("entity not found")
// ErrFailedToAllocatePort indicates no free port was found on host. CreateVM
// wraps the underlying port-search error with it.
ErrFailedToAllocatePort = errors.New("failed to allocate free port on host")
// ErrFailedToCalculateHash indicates that agent computation returned an error while calculating the hash of the computation.
ErrFailedToCalculateHash = errors.New("error while calculating the hash of the computation")
// ErrFailedToCreateAttestationPolicy indicates that the script to create the attestation policy failed to execute.
ErrFailedToCreateAttestationPolicy = errors.New("error while creating attestation policy")
// ErrFailedToReadPolicy indicates that the file for attestation policy could not be opened.
ErrFailedToReadPolicy = errors.New("error while opening file attestation policy")
// ErrUnmarshalFailed indicates that the file for the attestation policy could not be unmarshaled.
ErrUnmarshalFailed = errors.New("error while unmarshaling the attestation policy")
// ErrMaxVMsExceeded indicates that the maximum number of VMs has been reached.
// CreateVM returns it when the configured limit is already met.
ErrMaxVMsExceeded = errors.New("maximum number of VMs exceeded")
)
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
// CreateVM creates and starts a VM for a computation. It returns two
// strings and an error; in the managerService implementation these are the
// host port forwarded to the agent and the generated VM ID.
CreateVM(ctx context.Context, req *CreateReq) (string, string, error)
// RemoveVM stops the VM registered under computationID and removes it from
// the manager.
RemoveVM(ctx context.Context, computationID string) error
// FetchAttestationPolicy measures and fetches the attestation policy.
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
// ReturnCVMInfo returns CVM information needed for attestation verification and validation.
// In the managerService implementation the values are, in order: the OVMF
// code version, the SMP count, the CPU model and the EOS version.
ReturnCVMInfo(ctx context.Context) (string, int, string, string)
// Shutdown gracefully shuts down the service.
Shutdown() error
}
// managerService is the Service implementation that tracks VMs created through
// a vm.Provider and persists their state so they can be re-attached on restart.
type managerService struct {
// mu guards vms.
mu sync.Mutex
// qemuCfg is the base configuration copied into every new VM.
qemuCfg qemu.Config
// attestationPolicyBinaryPath and pcrValuesFilePath are stored by New; they
// are not read by the methods visible in this part of the file.
attestationPolicyBinaryPath string
pcrValuesFilePath string
// signingKey is loaded by New when a signing key path is supplied; it is nil
// otherwise.
signingKey crypto.Signer
logger *slog.Logger
// vms holds the VMs known to the manager, keyed by VM ID.
vms map[string]vm.VM
// vmFactory builds the vm.VM for a given configuration and ID.
vmFactory vm.Provider
// portRangeMin and portRangeMax are the inclusive bounds decoded from
// qemuCfg.HostFwdRange and searched when allocating the agent port.
portRangeMin int
portRangeMax int
// persistence stores VM state across manager restarts.
persistence qemu.Persistence
// eosVersion is reported verbatim by ReturnCVMInfo.
eosVersion string
// ttlManager schedules automatic VM removal for requests that carry a TTL.
ttlManager *TTLManager
// maxVMs caps the number of tracked VMs; values <= 0 disable the limit.
maxVMs int
}
// Compile-time check that managerService satisfies Service.
var _ Service = (*managerService)(nil)
// New builds the manager service.
//
// It decodes the host forwarding port range from cfg, opens the file-backed
// persistence store, optionally loads the signing key (an empty signingKeyPath
// leaves the signer nil) and finally re-attaches to any VMs found in the
// persisted state. Any failure in these steps aborts construction.
func New(cfg qemu.Config, attestationPolicyBinaryPath string, pcrValuesFilePath string, signingKeyPath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string, maxVMs int) (Service, error) {
	portMin, portMax, err := decodeRange(cfg.HostFwdRange)
	if err != nil {
		return nil, err
	}

	store, err := qemu.NewFilePersistence(persistenceDir)
	if err != nil {
		return nil, err
	}

	var signer crypto.Signer
	if signingKeyPath != "" {
		loaded, loadErr := corimgen.LoadSigningKey(signingKeyPath)
		if loadErr != nil {
			return nil, fmt.Errorf("failed to load signing key: %w", loadErr)
		}
		signer = loaded
	}

	svc := &managerService{
		logger:                      logger,
		qemuCfg:                     cfg,
		eosVersion:                  eosVersion,
		maxVMs:                      maxVMs,
		portRangeMin:                portMin,
		portRangeMax:                portMax,
		attestationPolicyBinaryPath: attestationPolicyBinaryPath,
		pcrValuesFilePath:           pcrValuesFilePath,
		signingKey:                  signer,
		vmFactory:                   vmFactory,
		persistence:                 store,
		vms:                         make(map[string]vm.VM),
		ttlManager:                  NewTTLManager(),
	}

	// Pick up VMs that survived a previous manager process.
	if restoreErr := svc.restoreVMs(); restoreErr != nil {
		return nil, restoreErr
	}

	return svc, nil
}
// CreateVM provisions and starts a new VM and registers it with the manager.
//
// It returns the host port forwarded to the agent (as a decimal string), the
// generated VM ID and an error. On failure the port is empty but the generated
// ID is still returned.
//
// The VM limit is checked twice: once up front to fail fast, and again after
// the VM has started, because the lock is not held while the VM boots and
// concurrent calls may have taken the remaining slots in the meantime.
//
// The TTL is validated before any resource is created so that a malformed TTL
// cannot leave a running, registered VM behind. Temporary certificate and
// environment directories are removed again when creation fails.
func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string, string, error) {
	id := uuid.New().String()

	ms.mu.Lock()
	if ms.maxVMs > 0 && len(ms.vms) >= ms.maxVMs {
		ms.mu.Unlock()
		return "", id, ErrMaxVMsExceeded
	}
	// LaunchTCB is left at 0; it is used for attestation verification, not
	// for VM creation.
	cfg := qemu.VMInfo{
		Config:    ms.qemuCfg,
		LaunchTCB: 0,
	}
	ms.mu.Unlock()

	if req.AaKbsParams != "" {
		cfg.Config.KernelCommandLine = fmt.Sprintf("%s agent.aa_kbc_params=%s", cfg.Config.KernelCommandLine, req.AaKbsParams)
	}

	// Reject an invalid TTL now, before files are written or a VM is started.
	var ttl time.Duration
	if req.Ttl != "" {
		var err error
		if ttl, err = time.ParseDuration(req.Ttl); err != nil {
			return "", id, err
		}
	}

	tmpCertsDir, err := tempCertMount(id, req)
	if err != nil {
		return "", id, err
	}

	tmpEnvDir, err := tmpEnvironment(id, req)
	if err != nil {
		ms.removeTempDirs(id, tmpCertsDir)
		return "", id, err
	}

	cfg.Config.CertsMount = tmpCertsDir
	cfg.Config.EnvMount = tmpEnvDir

	agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax)
	if err != nil {
		ms.removeTempDirs(id, tmpCertsDir, tmpEnvDir)
		return "", id, errors.Wrap(ErrFailedToAllocatePort, err)
	}
	cfg.Config.HostFwdAgent = agentPort

	if cfg.Config.EnableSEVSNP {
		todo := sha3.Sum256([]byte("TODO"))
		// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
		cfg.Config.SEVSNPConfig.HostData = base64.StdEncoding.EncodeToString(todo[:])
	}

	cvm := ms.vmFactory(cfg, id, ms.logger)
	if err = cvm.Start(); err != nil {
		ms.removeTempDirs(id, tmpCertsDir, tmpEnvDir)
		return "", id, err
	}

	ms.mu.Lock()
	if ms.maxVMs > 0 && len(ms.vms) >= ms.maxVMs {
		ms.mu.Unlock()
		if stopErr := cvm.Stop(); stopErr != nil {
			// The VM may still be alive; keep its directories in place.
			ms.logger.Error("Failed to stop VM after exceeding max limit", "vmID", id, "error", stopErr)
		} else {
			ms.removeTempDirs(id, tmpCertsDir, tmpEnvDir)
		}
		return "", id, ErrMaxVMsExceeded
	}
	ms.vms[id] = cvm
	ms.mu.Unlock()

	if req.Ttl != "" {
		ms.ttlManager.SetTTL(id, ttl, func() { //nolint:contextcheck
			if err := ms.RemoveVM(context.Background(), id); err != nil {
				ms.logger.Error("Failed to remove VM after TTL expiry", "vmID", id, "error", err)
			} else {
				ms.logger.Info("Successfully removed VM after TTL expiry", "vmID", id)
			}
		})
	}

	state := qemu.VMState{
		ID:     id,
		VMinfo: cfg,
		PID:    cvm.GetProcess(),
	}
	if err := ms.persistence.SaveVM(state); err != nil {
		ms.logger.Error("Failed to persist VM state", "error", err)
	}

	// The TTL callback may already have removed the VM (very short TTL), so
	// look it up instead of indexing the map blindly: calling a method on the
	// nil interface value of a missing entry would panic.
	ms.mu.Lock()
	if registered, ok := ms.vms[id]; ok {
		if err := registered.Transition(manager.VmRunning); err != nil {
			ms.logger.Warn("Failed to transition VM state", "cvm", id, "error", err)
		}
	}
	ms.mu.Unlock()

	return strconv.Itoa(agentPort), id, nil
}

// removeTempDirs deletes the per-VM temporary directories after a failed
// creation. Failures are only logged because the caller is already returning
// the error that caused the cleanup.
func (ms *managerService) removeTempDirs(id string, dirs ...string) {
	for _, dir := range dirs {
		if err := os.RemoveAll(dir); err != nil {
			ms.logger.Warn("Failed to remove temporary VM directory", "vmID", id, "dir", dir, "error", err)
		}
	}
}
// RemoveVM stops the VM registered under computationID and forgets it.
//
// Any pending TTL for the ID is cancelled first. ErrNotFound is returned when
// no such VM is registered, and a failure to stop the VM is returned as-is
// with the VM left registered. A failure to delete the persisted state is only
// logged.
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
	ms.mu.Lock()
	defer ms.mu.Unlock()

	ms.ttlManager.CancelTTL(computationID)

	target, found := ms.vms[computationID]
	if !found {
		return ErrNotFound
	}

	stopErr := target.Stop()
	if stopErr != nil {
		return stopErr
	}
	delete(ms.vms, computationID)

	delErr := ms.persistence.DeleteVM(computationID)
	if delErr != nil {
		ms.logger.Error("Failed to delete persisted VM state", "error", delErr)
	}

	return nil
}
// ReturnCVMInfo reports, in order, the OVMF code version, the SMP count and
// the CPU model from the QEMU configuration, followed by the EOS version.
func (ms *managerService) ReturnCVMInfo(ctx context.Context) (string, int, string, string) {
	cfg := &ms.qemuCfg

	ovmfVersion := cfg.OVMFCodeConfig.Version
	smpCount := cfg.SMPCount
	cpuModel := cfg.CPU

	return ovmfVersion, smpCount, cpuModel, ms.eosVersion
}
// Shutdown cancels every pending TTL and drops the in-memory VM registry.
// It does not call Stop on the tracked VMs and always returns nil.
func (ms *managerService) Shutdown() error {
	ms.logger.Info("Shutting down manager service")

	ms.ttlManager.CancelAll()

	ms.mu.Lock()
	ms.vms = map[string]vm.VM{}
	ms.mu.Unlock()

	return nil
}
// getFreePort returns the lowest port in [minPort, maxPort] that can currently
// be bound on the host, or an error when every port in the range is taken
// (or the range is empty).
//
// Ports are probed one at a time. Probing the whole range concurrently opens
// one listening socket per port at the same moment, which can exhaust the
// process's file descriptors on wide ranges (making free ports look busy) and
// leaves probe goroutines running after the function has returned.
//
// The result is only a snapshot: the port is released again before returning,
// so another process may still claim it before the caller binds it.
func getFreePort(minPort, maxPort int) (int, error) {
	for port := minPort; port <= maxPort; port++ {
		if checkPortisFree(port) {
			return port, nil
		}
	}

	return 0, fmt.Errorf("failed to find free port in range %d-%d", minPort, maxPort)
}
// checkPortisFree reports whether a TCP listener can be opened on the given
// port on all interfaces. The probe listener is closed again before returning.
func checkPortisFree(port int) bool {
	addr := fmt.Sprintf(":%d", port)

	probe, err := net.Listen("tcp", addr)
	if err == nil {
		// The close error is irrelevant to the answer and is ignored.
		_ = probe.Close()
		return true
	}

	return false
}
// hostFwdRangeRe matches an entire "<start>-<end>" range of decimal numbers,
// optionally surrounded by whitespace. It is compiled once at package level.
var hostFwdRangeRe = regexp.MustCompile(`^\s*(\d+)-(\d+)\s*$`)

// decodeRange parses a "<start>-<end>" string into its two integer bounds.
//
// The whole input must be a single range: inputs with extra text around or
// after the range (for example "1-2-3" or "x10-20y") are rejected instead of
// being silently truncated to their first match. An error is also returned
// when a bound does not fit in an int or when start is greater than end.
func decodeRange(input string) (int, int, error) {
	matches := hostFwdRangeRe.FindStringSubmatch(input)
	if len(matches) != 3 {
		return 0, 0, fmt.Errorf("invalid input format: %s", input)
	}

	// The pattern guarantees digits only, so Atoi can only fail on overflow.
	start, err := strconv.Atoi(matches[1])
	if err != nil {
		return 0, 0, err
	}
	end, err := strconv.Atoi(matches[2])
	if err != nil {
		return 0, 0, err
	}

	if start > end {
		return 0, 0, fmt.Errorf("invalid range: %d-%d", start, end)
	}

	return start, end, nil
}
// restoreVMs re-attaches the manager to VMs recorded in the persistence store.
//
// For every persisted state whose process no longer exists, the persisted
// entry is deleted. For the others a VM object is rebuilt through the factory,
// bound to the recorded PID, moved to the running state and registered. A VM
// that cannot be re-attached is skipped with a warning; only a failure to load
// the persisted states is returned as an error.
//
// It writes to ms.vms without taking the lock; in this file it is only called
// from New, before the service is returned to the caller.
func (ms *managerService) restoreVMs() error {
	states, err := ms.persistence.LoadVMs()
	if err != nil {
		return err
	}

	for _, state := range states {
		if !ms.processExists(state.PID) {
			// Report the deletion only when it actually happened.
			if err := ms.persistence.DeleteVM(state.ID); err != nil {
				ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err)
			} else {
				ms.logger.Info("Deleted persisted state for non-existent process", "computation", state.ID, "pid", state.PID)
			}
			continue
		}

		cvm := ms.vmFactory(state.VMinfo, state.ID, ms.logger)

		if err := cvm.SetProcess(state.PID); err != nil {
			ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
			continue
		}

		if err := cvm.Transition(manager.VmRunning); err != nil {
			ms.logger.Warn("Failed to transition VM state", "computation", state.ID, "error", err)
		}

		ms.vms[state.ID] = cvm
		ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID)
	}

	return nil
}
// processExists reports whether a process with the given PID is currently
// alive, using the signal-0 probe (no signal is delivered; only existence and
// permission are checked).
//
// A non-positive PID never identifies a single process and is reported as
// missing. EPERM is reported as existing: it means the process is there but
// the manager is not allowed to signal it (for example when it runs under a
// different user). Every other probe error is treated as "not running".
func (ms *managerService) processExists(pid int) bool {
	if pid <= 0 {
		return false
	}

	process, err := os.FindProcess(pid)
	if err != nil {
		ms.logger.Warn("Failed to find process", "pid", pid, "error", err)
		return false
	}

	switch err = process.Signal(syscall.Signal(0)); err {
	case nil, syscall.EPERM:
		return true
	default:
		return false
	}
}
// tempCertMount creates a fresh temporary directory under /tmp (prefixed with
// id) and writes the client certificate, client key and server CA certificate
// from req into it as cert.pem, key.pem and ca.pem. It returns the directory
// path.
//
// A file is written for each of the three even when the corresponding request
// field is empty. If any write fails, the directory is removed again so that
// partially written key material is not left behind.
//
// NOTE(review): the files, including key.pem, are created with mode 0644;
// os.MkdirTemp creates the enclosing directory with mode 0700.
func tempCertMount(id string, req *CreateReq) (string, error) {
	dir, err := os.MkdirTemp("/tmp", id)
	if err != nil {
		return "", err
	}

	files := []struct {
		name    string
		content []byte
	}{
		{"cert.pem", req.AgentCvmClientCert},
		{"key.pem", req.AgentCvmClientKey},
		{"ca.pem", req.AgentCvmServerCaCert},
	}

	for _, f := range files {
		if err := os.WriteFile(fmt.Sprintf("%s/%s", dir, f.name), f.content, 0o644); err != nil {
			// Best-effort cleanup; the write error is the one worth reporting.
			_ = os.RemoveAll(dir)
			return "", err
		}
	}

	return dir, nil
}
// tmpEnvironment creates a fresh temporary directory under /tmp (prefixed with
// id) containing the CVM environment file, one KEY=VALUE line per variable,
// and returns the directory path.
//
// The agent log level, gRPC URL, CVM ID, certs token and CA URL are always
// written. The certificate path variables are written only when the matching
// request field is non-nil, and the AWS variables only when the matching
// request field is non-empty.
//
// Variables are written in a fixed order, and the file is written in a single
// call so no file handle can be left open on failure. If the write fails, the
// directory is removed again.
//
// NOTE(review): values are written verbatim; a value containing a newline
// would produce additional lines in the file. Confirm that request fields are
// validated upstream.
func tmpEnvironment(id string, req *CreateReq) (string, error) {
	dir, err := os.MkdirTemp("/tmp", id)
	if err != nil {
		return "", err
	}

	type envVar struct {
		key, value string
	}

	vars := []envVar{
		{agentLogLevelKey, req.AgentLogLevel},
		{agentCvmGrpcUrlKey, req.AgentCvmServerUrl},
		{agentCvmId, id},
		{agentCaToken, req.AgentCertsToken},
		{agentCvmCaUrl, req.AgentCvmCaUrl},
	}

	if req.AgentCvmClientCert != nil {
		vars = append(vars, envVar{agentCvmClientCertKey, defClientCertPath})
	}
	if req.AgentCvmClientKey != nil {
		vars = append(vars, envVar{agentCvmClientKey, defClientKeyPath})
	}
	if req.AgentCvmServerCaCert != nil {
		vars = append(vars, envVar{agentCvmServerCaCertKey, defServerCaCertPath})
	}

	// Add AWS credentials if provided.
	if req.AwsAccessKeyId != "" {
		vars = append(vars, envVar{awsAccessKeyIdKey, req.AwsAccessKeyId})
	}
	if req.AwsSecretAccessKey != "" {
		vars = append(vars, envVar{awsSecretAccessKeyKey, req.AwsSecretAccessKey})
	}
	if req.AwsEndpointUrl != "" {
		vars = append(vars, envVar{awsEndpointUrlKey, req.AwsEndpointUrl})
	}
	if req.AwsRegion != "" {
		vars = append(vars, envVar{awsRegionKey, req.AwsRegion})
	}

	var content []byte
	for _, v := range vars {
		content = fmt.Appendf(content, "%s=%s\n", v.key, v.value)
	}

	if err := os.WriteFile(fmt.Sprintf("%s/%s", dir, cvmEnvironmentFile), content, 0o644); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = os.RemoveAll(dir)
		return "", err
	}

	return dir, nil
}