mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
8eb1fac9ad
* Refactor and update dependencies in the project - Updated go.sum to replace `github.com/absmach/magistrala` with `github.com/absmach/supermq` across various modules. - Removed VSock configuration from environment variables and QEMU arguments. - Updated QEMU configuration and related tests to remove references to guest CID and VSock. - Added new HTTP transport layer for API endpoints in the manager. - Introduced Prometheus monitoring configuration with alert rules and Alertmanager setup. - Updated service and VM interfaces to remove unused methods and references. - Refactored tests to align with the new structure and dependencies. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MaxVMs configuration and enforce limit on VM creation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for HTTP transport handlers and endpoints Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add test case for exceeding maximum number of VMs in TestRun Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in TestHandlerWithCustomRouter to ensure response writing is checked Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update dependencies to latest versions - Upgrade cel.dev/expr from v0.23.0 to v0.24.0 - Upgrade github.com/absmach/supermq from v0.16.0 to v0.17.0 - Upgrade github.com/cenkalti/backoff from v4.3.0 to v5.0.2 - Upgrade github.com/cncf/xds/go to v0.0.0-20250501225837-2ac532fd4443 - Upgrade github.com/go-chi/chi/v5 from v5.2.1 to v5.2.2 - Upgrade github.com/go-jose/go-jose/v3 from v3.0.3 to v3.0.4 - Upgrade github.com/gofrs/uuid/v5 from v5.3.0 to v5.3.2 - Upgrade github.com/prometheus/client_golang from v1.22.0 to v1.23.0 - Upgrade github.com/prometheus/client_model from v0.6.1 to v0.6.2 - Upgrade github.com/prometheus/common from v0.62.0 to v0.65.0 - Upgrade github.com/prometheus/procfs from v0.15.1 to v0.16.1 - Upgrade go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp from v0.60.0 to v0.62.0 - Upgrade 
go.opentelemetry.io/otel/exporters/otlp/otlptrace from v1.36.0 to v1.37.0 - Upgrade golang.org/x/crypto from v0.39.0 to v0.40.0 - Upgrade golang.org/x/sys from v0.33.0 to v0.34.0 - Upgrade golang.org/x/text from v0.26.0 to v0.27.0 - Upgrade golang.org/x/time from v0.11.0 to v0.12.0 - Upgrade google.golang.org/grpc from v1.73.0 to v1.74.2 Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
479 lines
13 KiB
Go
479 lines
13 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package manager
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/absmach/supermq/pkg/errors"
|
|
"github.com/google/go-sev-guest/proto/check"
|
|
"github.com/google/uuid"
|
|
"github.com/ultravioletrs/cocos/manager/qemu"
|
|
"github.com/ultravioletrs/cocos/manager/vm"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
|
"github.com/ultravioletrs/cocos/pkg/manager"
|
|
"golang.org/x/crypto/sha3"
|
|
)
|
|
|
|
// Host-side locations and sizes.
const (
	// hashLength is 32 (bytes). NOTE(review): not referenced in this part of
	// the file; presumably the length of a SHA3-256 digest — confirm usage.
	hashLength = 32
	// persistenceDir is the directory handed to qemu.NewFilePersistence in New.
	persistenceDir = "/tmp/cocos"
	// cvmEnvironmentFile is the name of the KEY=VALUE file that tmpEnvironment
	// writes inside the per-VM environment mount directory.
	cvmEnvironmentFile = "environment"
)

// Names of the environment variables written for the in-VM agent.
const (
	agentLogLevelKey        = "AGENT_LOG_LEVEL"
	agentCvmGrpcUrlKey      = "AGENT_CVM_GRPC_URL"
	agentCvmClientCertKey   = "AGENT_CVM_GRPC_CLIENT_CERT"
	agentCvmClientKey       = "AGENT_CVM_GRPC_CLIENT_KEY"
	agentCvmServerCaCertKey = "AGENT_CVM_GRPC_SERVER_CA_CERTS"
	agentCvmId              = "AGENT_CVM_ID"
	agentCvmCaUrl           = "AGENT_CVM_CA_URL"
)

// Certificate paths advertised to the agent through its environment.
// NOTE(review): presumably where the certs mount (cert.pem, key.pem, ca.pem)
// appears inside the guest — confirm against the VM mount configuration.
const (
	defClientCertPath   = "/etc/certs/cert.pem"
	defClientKeyPath    = "/etc/certs/key.pem"
	defServerCaCertPath = "/etc/certs/ca.pem"
)
|
|
|
|
// Sentinel errors returned by the manager service.
var (
	// ErrMalformedEntity is returned when an entity specification is malformed.
	ErrMalformedEntity = errors.New("malformed entity specification")
	// ErrUnauthorizedAccess is returned when credentials for a protected
	// resource are missing or invalid.
	ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
	// ErrNotFound is returned when the requested entity (e.g. a VM ID) is unknown.
	ErrNotFound = errors.New("entity not found")
	// ErrFailedToAllocatePort is returned when no free host port is available
	// in the configured forwarding range.
	ErrFailedToAllocatePort = errors.New("failed to allocate free port on host")
	// ErrFailedToCalculateHash reports a failure while computing the hash of a
	// computation. NOTE(review): not referenced in this part of the file.
	ErrFailedToCalculateHash = errors.New("error while calculating the hash of the computation")
	// ErrFailedToCreateAttestationPolicy is returned when running the
	// attestation-policy binary fails.
	ErrFailedToCreateAttestationPolicy = errors.New("error while creating attestation policy")
	// ErrFailedToReadPolicy reports that the attestation policy file could not be opened.
	ErrFailedToReadPolicy = errors.New("error while opening file attestation policy")
	// ErrUnmarshalFailed is returned when the attestation policy cannot be decoded.
	ErrUnmarshalFailed = errors.New("error while unmarshaling the attestation policy")
	// ErrMaxVMsExceeded is returned when creating a VM would exceed the
	// configured maximum number of VMs.
	ErrMaxVMsExceeded = errors.New("maximum number of VMs exceeded")
)
|
|
|
|
// Service specifies an API that must be fulfilled by the domain service
|
|
// implementation, and all of its decorators (e.g. logging & metrics).
|
|
type Service interface {
|
|
// Run create a computation.
|
|
CreateVM(ctx context.Context, req *CreateReq) (string, string, error)
|
|
// Stop stops a computation.
|
|
RemoveVM(ctx context.Context, computationID string) error
|
|
// FetchAttestationPolicy measures and fetches the attestation policy.
|
|
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
|
|
// ReturnCVMInfo returns CVM information needed for attestation verification and validation.
|
|
ReturnCVMInfo(ctx context.Context) (string, int, string, string)
|
|
// Shutdown gracefully shuts down the service
|
|
Shutdown() error
|
|
}
|
|
|
|
// managerService is the default Service implementation: it launches VMs
// through vmFactory, tracks them in vms and persists their state so a later
// process can re-attach to them (see restoreVMs).
type managerService struct {
	// mu guards vms.
	mu sync.Mutex
	// ap serializes runs of the attestation-policy binary in CreateVM.
	ap sync.Mutex
	// qemuCfg is the base QEMU configuration copied into every new VM.
	qemuCfg qemu.Config
	// attestationPolicyBinaryPath is the binary run to produce the SEV-SNP
	// attestation policy.
	attestationPolicyBinaryPath string
	// NOTE(review): not read in this part of the file; presumably used by the
	// attestation-policy helpers — confirm.
	igvmMeasurementBinaryPath string
	// NOTE(review): not read in this part of the file; presumably used by the
	// attestation-policy helpers — confirm.
	pcrValuesFilePath string
	logger            *slog.Logger
	// vms maps VM (computation) IDs to their running VM.
	vms map[string]vm.VM
	// vmFactory builds a VM from its VMInfo, ID and logger.
	vmFactory vm.Provider
	// portRangeMin and portRangeMax bound the host ports (inclusive) from
	// which agent forwarding ports are allocated.
	portRangeMin int
	portRangeMax int
	// persistence stores VM state so it can be restored by restoreVMs.
	persistence qemu.Persistence
	// eosVersion is reported by ReturnCVMInfo.
	eosVersion string
	// ttlManager schedules automatic VM removal for requests that set a TTL.
	ttlManager *TTLManager
	// maxVMs caps the number of tracked VMs; a value <= 0 disables the limit.
	maxVMs int
}

// Compile-time check that managerService implements Service.
var _ Service = (*managerService)(nil)
|
|
|
|
// New instantiates the manager service implementation.
|
|
func New(cfg qemu.Config, attestationPolicyBinPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string, maxVMs int) (Service, error) {
|
|
start, end, err := decodeRange(cfg.HostFwdRange)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
persistence, err := qemu.NewFilePersistence(persistenceDir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ms := &managerService{
|
|
qemuCfg: cfg,
|
|
logger: logger,
|
|
vms: make(map[string]vm.VM),
|
|
vmFactory: vmFactory,
|
|
attestationPolicyBinaryPath: attestationPolicyBinPath,
|
|
igvmMeasurementBinaryPath: igvmMeasurementBinaryPath,
|
|
pcrValuesFilePath: pcrValuesFilePath,
|
|
portRangeMin: start,
|
|
portRangeMax: end,
|
|
persistence: persistence,
|
|
eosVersion: eosVersion,
|
|
ttlManager: NewTTLManager(),
|
|
maxVMs: maxVMs,
|
|
}
|
|
|
|
if err := ms.restoreVMs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ms, nil
|
|
}
|
|
|
|
func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string, string, error) {
|
|
id := uuid.New().String()
|
|
|
|
ms.mu.Lock()
|
|
if ms.maxVMs > 0 && len(ms.vms) >= ms.maxVMs {
|
|
ms.mu.Unlock()
|
|
return "", id, ErrMaxVMsExceeded
|
|
}
|
|
|
|
cfg := qemu.VMInfo{
|
|
Config: ms.qemuCfg,
|
|
LaunchTCB: 0,
|
|
}
|
|
ms.mu.Unlock()
|
|
|
|
tmpCertsDir, err := tempCertMount(id, req)
|
|
if err != nil {
|
|
return "", id, err
|
|
}
|
|
|
|
tmpEnvDir, err := tmpEnvironment(id, req)
|
|
if err != nil {
|
|
return "", id, err
|
|
}
|
|
|
|
cfg.Config.CertsMount = tmpCertsDir
|
|
cfg.Config.EnvMount = tmpEnvDir
|
|
|
|
if ms.qemuCfg.EnableSEVSNP {
|
|
attestPolicyCmd, err := fetchSNPAttestationPolicy(ms)
|
|
if err != nil {
|
|
return "", id, err
|
|
}
|
|
|
|
var stdOutByte []byte
|
|
ms.ap.Lock()
|
|
stdOutByte, err = attestPolicyCmd.Run(ms.attestationPolicyBinaryPath)
|
|
ms.ap.Unlock()
|
|
if err != nil {
|
|
return "", id, errors.Wrap(ErrFailedToCreateAttestationPolicy, err)
|
|
}
|
|
|
|
attestationPolicy := attestation.Config{Config: &check.Config{RootOfTrust: &check.RootOfTrust{}, Policy: &check.Policy{}}, PcrConfig: &attestation.PcrConfig{}}
|
|
|
|
if err = vtpm.ReadPolicyFromByte(stdOutByte, &attestationPolicy); err != nil {
|
|
return "", id, errors.Wrap(ErrUnmarshalFailed, err)
|
|
}
|
|
|
|
// Define the TCB that was present at launch of the VM.
|
|
cfg.LaunchTCB = attestationPolicy.Config.Policy.MinimumLaunchTcb
|
|
}
|
|
|
|
agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax)
|
|
if err != nil {
|
|
return "", id, errors.Wrap(ErrFailedToAllocatePort, err)
|
|
}
|
|
cfg.Config.HostFwdAgent = agentPort
|
|
|
|
if cfg.Config.EnableSEVSNP {
|
|
todo := sha3.Sum256([]byte("TODO"))
|
|
// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
|
|
cfg.Config.SEVSNPConfig.HostData = base64.StdEncoding.EncodeToString(todo[:])
|
|
}
|
|
|
|
cvm := ms.vmFactory(cfg, id, ms.logger)
|
|
if err = cvm.Start(); err != nil {
|
|
return "", id, err
|
|
}
|
|
|
|
ms.mu.Lock()
|
|
if ms.maxVMs > 0 && len(ms.vms) >= ms.maxVMs {
|
|
ms.mu.Unlock()
|
|
if stopErr := cvm.Stop(); stopErr != nil {
|
|
ms.logger.Error("Failed to stop VM after exceeding max limit", "vmID", id, "error", stopErr)
|
|
}
|
|
return "", id, ErrMaxVMsExceeded
|
|
}
|
|
ms.vms[id] = cvm
|
|
ms.mu.Unlock()
|
|
|
|
if req.Ttl != "" {
|
|
ttl, err := time.ParseDuration(req.Ttl)
|
|
if err != nil {
|
|
return "", id, err
|
|
}
|
|
|
|
ms.ttlManager.SetTTL(id, ttl, func() { //nolint:contextcheck
|
|
if err := ms.RemoveVM(context.Background(), id); err != nil {
|
|
ms.logger.Error("Failed to remove VM after TTL expiry", "vmID", id, "error", err)
|
|
} else {
|
|
ms.logger.Info("Successfully removed VM after TTL expiry", "vmID", id)
|
|
}
|
|
})
|
|
}
|
|
|
|
pid := cvm.GetProcess()
|
|
|
|
state := qemu.VMState{
|
|
ID: id,
|
|
VMinfo: cfg,
|
|
PID: pid,
|
|
}
|
|
if err := ms.persistence.SaveVM(state); err != nil {
|
|
ms.logger.Error("Failed to persist VM state", "error", err)
|
|
}
|
|
|
|
ms.mu.Lock()
|
|
if err := ms.vms[id].Transition(manager.VmRunning); err != nil {
|
|
ms.logger.Warn("Failed to transition VM state", "cvm", id, "error", err)
|
|
}
|
|
ms.mu.Unlock()
|
|
|
|
return fmt.Sprint(agentPort), id, nil
|
|
}
|
|
|
|
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
|
|
ms.mu.Lock()
|
|
defer ms.mu.Unlock()
|
|
|
|
ms.ttlManager.CancelTTL(computationID)
|
|
|
|
cvm, ok := ms.vms[computationID]
|
|
if !ok {
|
|
return ErrNotFound
|
|
}
|
|
if err := cvm.Stop(); err != nil {
|
|
return err
|
|
}
|
|
delete(ms.vms, computationID)
|
|
|
|
if err := ms.persistence.DeleteVM(computationID); err != nil {
|
|
ms.logger.Error("Failed to delete persisted VM state", "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ms *managerService) ReturnCVMInfo(ctx context.Context) (string, int, string, string) {
|
|
return ms.qemuCfg.OVMFCodeConfig.Version, ms.qemuCfg.SMPCount, ms.qemuCfg.CPU, ms.eosVersion
|
|
}
|
|
|
|
// Shutdown gracefully shuts down the service.
|
|
func (ms *managerService) Shutdown() error {
|
|
ms.logger.Info("Shutting down manager service")
|
|
|
|
ms.ttlManager.CancelAll()
|
|
|
|
ms.mu.Lock()
|
|
defer ms.mu.Unlock()
|
|
|
|
ms.vms = make(map[string]vm.VM)
|
|
|
|
return nil
|
|
}
|
|
|
|
func getFreePort(minPort, maxPort int) (int, error) {
|
|
if checkPortisFree(minPort) {
|
|
return minPort, nil
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
portCh := make(chan int, 1)
|
|
|
|
for port := minPort; port <= maxPort; port++ {
|
|
wg.Add(1)
|
|
go func(p int) {
|
|
defer wg.Done()
|
|
if checkPortisFree(p) {
|
|
select {
|
|
case portCh <- p:
|
|
default:
|
|
}
|
|
}
|
|
}(port)
|
|
}
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(portCh)
|
|
}()
|
|
|
|
port, ok := <-portCh
|
|
if !ok {
|
|
return 0, fmt.Errorf("failed to find free port in range %d-%d", minPort, maxPort)
|
|
}
|
|
|
|
return port, nil
|
|
}
|
|
|
|
// checkPortisFree reports whether a TCP listener can currently be opened on
// the given port on all interfaces. The probe listener is closed immediately.
// An invalid port (out of range) makes the listen fail and yields false.
func checkPortisFree(port int) bool {
	ln, listenErr := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if listenErr != nil {
		return false
	}
	_ = ln.Close() // probe only; a close error does not change the answer
	return true
}
|
|
|
|
// decodeRange extracts a "<start>-<end>" pair of decimal numbers from input
// and returns them as (start, end). The match is not anchored, so surrounding
// text is ignored and the first "<digits>-<digits>" occurrence is used. It
// fails if no such pair is present, if a number does not fit in an int, or if
// start > end.
func decodeRange(input string) (int, int, error) {
	pattern := regexp.MustCompile(`(\d+)-(\d+)`)
	groups := pattern.FindStringSubmatch(input)
	if len(groups) != 3 {
		return 0, 0, fmt.Errorf("invalid input format: %s", input)
	}

	bounds := make([]int, 0, 2)
	for _, digits := range groups[1:] {
		n, convErr := strconv.Atoi(digits)
		if convErr != nil {
			return 0, 0, convErr
		}
		bounds = append(bounds, n)
	}

	lo, hi := bounds[0], bounds[1]
	if lo > hi {
		return 0, 0, fmt.Errorf("invalid range: %d-%d", lo, hi)
	}
	return lo, hi, nil
}
|
|
|
|
func (ms *managerService) restoreVMs() error {
|
|
states, err := ms.persistence.LoadVMs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, state := range states {
|
|
if !ms.processExists(state.PID) {
|
|
if err := ms.persistence.DeleteVM(state.ID); err != nil {
|
|
ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err)
|
|
}
|
|
ms.logger.Info("Deleted persisted state for non-existent process", "computation", state.ID, "pid", state.PID)
|
|
continue
|
|
}
|
|
|
|
cvm := ms.vmFactory(state.VMinfo, state.ID, ms.logger)
|
|
|
|
if err = cvm.SetProcess(state.PID); err != nil {
|
|
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
|
|
continue
|
|
}
|
|
|
|
if err := cvm.Transition(manager.VmRunning); err != nil {
|
|
ms.logger.Warn("Failed to transition VM state", "computation", state.ID, "error", err)
|
|
}
|
|
|
|
ms.vms[state.ID] = cvm
|
|
ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ms *managerService) processExists(pid int) bool {
|
|
process, err := os.FindProcess(pid)
|
|
if err != nil {
|
|
ms.logger.Warn("Failed to find process", "pid", pid, "error", err)
|
|
return false
|
|
}
|
|
|
|
if err = process.Signal(syscall.Signal(0)); err == nil {
|
|
return true
|
|
}
|
|
if err == syscall.ESRCH {
|
|
return false
|
|
}
|
|
return false
|
|
}
|
|
|
|
func tempCertMount(id string, req *CreateReq) (string, error) {
|
|
dir, err := os.MkdirTemp("/tmp", id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err = os.WriteFile(fmt.Sprintf("%s/%s", dir, "cert.pem"), req.AgentCvmClientCert, 0o644); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err = os.WriteFile(fmt.Sprintf("%s/%s", dir, "key.pem"), req.AgentCvmClientKey, 0o644); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err = os.WriteFile(fmt.Sprintf("%s/%s", dir, "ca.pem"), req.AgentCvmServerCaCert, 0o644); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return dir, nil
|
|
}
|
|
|
|
func tmpEnvironment(id string, req *CreateReq) (string, error) {
|
|
dir, err := os.MkdirTemp("/tmp", id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
envMap := map[string]string{
|
|
agentLogLevelKey: req.AgentLogLevel,
|
|
agentCvmGrpcUrlKey: req.AgentCvmServerUrl,
|
|
agentCvmId: id,
|
|
agentCvmCaUrl: req.AgentCvmCaUrl,
|
|
}
|
|
|
|
if req.AgentCvmClientCert != nil {
|
|
envMap[agentCvmClientCertKey] = defClientCertPath
|
|
}
|
|
if req.AgentCvmClientKey != nil {
|
|
envMap[agentCvmClientKey] = defClientKeyPath
|
|
}
|
|
if req.AgentCvmServerCaCert != nil {
|
|
envMap[agentCvmServerCaCertKey] = defServerCaCertPath
|
|
}
|
|
|
|
envFile, err := os.OpenFile(fmt.Sprintf("%s/%s", dir, cvmEnvironmentFile), os.O_CREATE|os.O_WRONLY, 0o644)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for k, v := range envMap {
|
|
if _, err = envFile.WriteString(fmt.Sprintf("%s=%s\n", k, v)); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
if err = envFile.Close(); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return dir, nil
|
|
}
|