mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
8eb1fac9ad
* Refactor and update dependencies in the project - Updated go.sum to replace `github.com/absmach/magistrala` with `github.com/absmach/supermq` across various modules. - Removed VSock configuration from environment variables and QEMU arguments. - Updated QEMU configuration and related tests to remove references to guest CID and VSock. - Added new HTTP transport layer for API endpoints in the manager. - Introduced Prometheus monitoring configuration with alert rules and Alertmanager setup. - Updated service and VM interfaces to remove unused methods and references. - Refactored tests to align with the new structure and dependencies. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MaxVMs configuration and enforce limit on VM creation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for HTTP transport handlers and endpoints Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add test case for exceeding maximum number of VMs in TestRun Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in TestHandlerWithCustomRouter to ensure response writing is checked Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update dependencies to latest versions - Upgrade cel.dev/expr from v0.23.0 to v0.24.0 - Upgrade github.com/absmach/supermq from v0.16.0 to v0.17.0 - Upgrade github.com/cenkalti/backoff from v4.3.0 to v5.0.2 - Upgrade github.com/cncf/xds/go to v0.0.0-20250501225837-2ac532fd4443 - Upgrade github.com/go-chi/chi/v5 from v5.2.1 to v5.2.2 - Upgrade github.com/go-jose/go-jose/v3 from v3.0.3 to v3.0.4 - Upgrade github.com/gofrs/uuid/v5 from v5.3.0 to v5.3.2 - Upgrade github.com/prometheus/client_golang from v1.22.0 to v1.23.0 - Upgrade github.com/prometheus/client_model from v0.6.1 to v0.6.2 - Upgrade github.com/prometheus/common from v0.62.0 to v0.65.0 - Upgrade github.com/prometheus/procfs from v0.15.1 to v0.16.1 - Upgrade go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp from v0.60.0 to v0.62.0 - Upgrade go.opentelemetry.io/otel/exporters/otlp/otlptrace from v1.36.0 to v1.37.0 - Upgrade golang.org/x/crypto from v0.39.0 to v0.40.0 - Upgrade golang.org/x/sys from v0.33.0 to v0.34.0 - Upgrade golang.org/x/text from v0.26.0 to v0.27.0 - Upgrade golang.org/x/time from v0.11.0 to v0.12.0 - Upgrade google.golang.org/grpc from v1.73.0 to v1.74.2 Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
380 lines
9.7 KiB
Go
380 lines
9.7 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/absmach/supermq/pkg/errors"
|
|
"github.com/ultravioletrs/cocos/agent"
|
|
"github.com/ultravioletrs/cocos/agent/cvms"
|
|
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
|
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
|
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
|
"golang.org/x/sync/errgroup"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
const (
|
|
reconnectInterval = 5 * time.Second
|
|
sendTimeout = 5 * time.Second
|
|
pendingMsgFile = "pending_messages.json"
|
|
)
|
|
|
|
var (
|
|
errCorruptedManifest = errors.New("received manifest may be corrupted")
|
|
errUnknonwMessageType = errors.New("unknown message type")
|
|
)
|
|
|
|
type PendingMessage struct {
|
|
Message *cvms.ClientStreamMessage
|
|
Time time.Time
|
|
}
|
|
|
|
type CVMSClient struct {
|
|
mu sync.Mutex
|
|
stream cvms.Service_ProcessClient
|
|
svc agent.Service
|
|
messageQueue chan *cvms.ClientStreamMessage
|
|
logger *slog.Logger
|
|
runReqManager *runRequestManager
|
|
sp server.AgentServer
|
|
storage storage.Storage
|
|
reconnectFn func(context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error)
|
|
grpcClient pkggrpc.Client
|
|
}
|
|
|
|
// NewClient returns new gRPC client instance.
|
|
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, storageDir string, reconnectFn func(context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error), grpcClient pkggrpc.Client) (*CVMSClient, error) {
|
|
store, err := storage.NewFileStorage(storageDir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CVMSClient{
|
|
stream: stream,
|
|
svc: svc,
|
|
messageQueue: messageQueue,
|
|
logger: logger,
|
|
runReqManager: newRunRequestManager(),
|
|
sp: sp,
|
|
storage: store,
|
|
reconnectFn: reconnectFn,
|
|
grpcClient: grpcClient,
|
|
}, nil
|
|
}
|
|
|
|
func (client *CVMSClient) Process(ctx context.Context, cancel context.CancelFunc) error {
|
|
for {
|
|
err := client.processWithRetry(ctx)
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
|
|
client.logger.Info("Connection lost, attempting to reconnect...", "error", err)
|
|
time.Sleep(reconnectInterval)
|
|
|
|
grpcClient, stream, err := client.reconnectFn(ctx)
|
|
if err != nil {
|
|
client.logger.Error("Failed to reconnect", "error", err)
|
|
continue
|
|
}
|
|
|
|
client.mu.Lock()
|
|
client.stream = stream
|
|
client.grpcClient = grpcClient
|
|
client.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (client *CVMSClient) processWithRetry(ctx context.Context) error {
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
|
|
eg.Go(func() error {
|
|
return client.handleIncomingMessages(ctx)
|
|
})
|
|
|
|
eg.Go(func() error {
|
|
return client.handleOutgoingMessages(ctx)
|
|
})
|
|
|
|
return eg.Wait()
|
|
}
|
|
|
|
func (client *CVMSClient) handleIncomingMessages(ctx context.Context) error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
req, err := client.stream.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := client.processIncomingMessage(ctx, req); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *CVMSClient) handleOutgoingMessages(ctx context.Context) error {
|
|
pendingMsgs, err := client.storage.Load()
|
|
if err != nil {
|
|
client.logger.Error("Failed to load pending messages", "error", err)
|
|
} else {
|
|
client.sendPendingMessages(pendingMsgs)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case msg := <-client.messageQueue:
|
|
if err := client.sendStreamMessage(msg); err != nil {
|
|
if err := client.storage.Add(msg); err != nil {
|
|
client.logger.Error("Failed to store pending message", "error", err)
|
|
}
|
|
client.logger.Error("Failed to send message, stored for retry", "error", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *CVMSClient) sendStreamMessage(msg *cvms.ClientStreamMessage) error {
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
|
|
return client.stream.Send(msg)
|
|
}
|
|
|
|
func (client *CVMSClient) sendPendingMessages(pending []storage.Message) {
|
|
for _, pm := range pending {
|
|
if err := client.sendStreamMessage(pm.Message); err != nil {
|
|
if err := client.storage.Add(pm.Message); err != nil {
|
|
client.logger.Error("Failed to store pending message", "error", err)
|
|
}
|
|
client.logger.Error("Failed to resend pending message", "error", err)
|
|
} else {
|
|
client.logger.Info("Successfully resent pending message")
|
|
}
|
|
}
|
|
|
|
if err := client.storage.Clear(); err != nil {
|
|
client.logger.Error("Failed to clear pending messages", "error", err)
|
|
}
|
|
}
|
|
|
|
func (client *CVMSClient) processIncomingMessage(ctx context.Context, req *cvms.ServerStreamMessage) error {
|
|
switch mes := req.Message.(type) {
|
|
case *cvms.ServerStreamMessage_RunReqChunks:
|
|
return client.handleRunReqChunks(ctx, mes)
|
|
case *cvms.ServerStreamMessage_StopComputation:
|
|
go client.handleStopComputation(ctx, mes)
|
|
case *cvms.ServerStreamMessage_AgentStateReq:
|
|
client.handleAgentStateReq(mes)
|
|
case *cvms.ServerStreamMessage_DisconnectReq:
|
|
client.logger.Info("Received disconnect request")
|
|
client.mu.Lock()
|
|
if err := client.grpcClient.Close(); err != nil {
|
|
client.logger.Error("Failed to close gRPC client", "error", err)
|
|
}
|
|
client.mu.Unlock()
|
|
default:
|
|
return errUnknonwMessageType
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (client *CVMSClient) handleAgentStateReq(mes *cvms.ServerStreamMessage_AgentStateReq) {
|
|
state := client.svc.State()
|
|
|
|
msg := &cvms.ClientStreamMessage_AgentStateRes{
|
|
AgentStateRes: &cvms.AgentStateRes{
|
|
State: state,
|
|
Id: mes.AgentStateReq.Id,
|
|
},
|
|
}
|
|
|
|
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
|
|
}
|
|
|
|
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
|
|
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
|
|
|
|
if complete {
|
|
var runReq cvms.ComputationRunReq
|
|
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
|
return errors.Wrap(err, errCorruptedManifest)
|
|
}
|
|
|
|
go client.executeRun(ctx, &runReq)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.ComputationRunReq) {
|
|
ac := agent.Computation{
|
|
ID: runReq.Id,
|
|
Name: runReq.Name,
|
|
Description: runReq.Description,
|
|
}
|
|
|
|
if runReq.Algorithm != nil {
|
|
ac.Algorithm = agent.Algorithm{
|
|
Hash: [32]byte(runReq.Algorithm.Hash),
|
|
UserKey: runReq.Algorithm.UserKey,
|
|
}
|
|
}
|
|
|
|
for _, ds := range runReq.Datasets {
|
|
ac.Datasets = append(ac.Datasets, agent.Dataset{
|
|
Hash: [32]byte(ds.Hash),
|
|
UserKey: ds.UserKey,
|
|
})
|
|
}
|
|
|
|
for _, rc := range runReq.ResultConsumers {
|
|
ac.ResultConsumers = append(ac.ResultConsumers, agent.ResultConsumer{
|
|
UserKey: rc.UserKey,
|
|
})
|
|
}
|
|
|
|
if err := client.svc.InitComputation(ctx, ac); err != nil {
|
|
client.logger.Warn(err.Error())
|
|
return
|
|
}
|
|
|
|
ccPlatform := attestation.CCPlatform()
|
|
|
|
client.mu.Lock()
|
|
defer client.mu.Unlock()
|
|
|
|
if runReq.AgentConfig == nil {
|
|
runReq.AgentConfig = &cvms.AgentConfig{}
|
|
}
|
|
|
|
runRes := &cvms.ClientStreamMessage_RunRes{
|
|
RunRes: &cvms.RunResponse{
|
|
ComputationId: runReq.Id,
|
|
},
|
|
}
|
|
|
|
if err := client.sp.Start(agent.AgentConfig{
|
|
Port: runReq.AgentConfig.Port,
|
|
CertFile: runReq.AgentConfig.CertFile,
|
|
KeyFile: runReq.AgentConfig.KeyFile,
|
|
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
|
ClientCAFile: runReq.AgentConfig.ClientCaFile,
|
|
AttestedTls: runReq.AgentConfig.AttestedTls,
|
|
}, ac); err != nil {
|
|
client.logger.Warn(err.Error())
|
|
runRes.RunRes.Error = err.Error()
|
|
}
|
|
|
|
defer func() {
|
|
if ccPlatform == attestation.Azure || ccPlatform == attestation.SNPvTPM {
|
|
cmpJson, err := json.Marshal(ac)
|
|
if err != nil {
|
|
client.logger.Error(err.Error())
|
|
return
|
|
}
|
|
if err = vtpm.ExtendPCR(vtpm.PCR16, cmpJson); err != nil {
|
|
client.logger.Error(err.Error())
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
client.sendMessage(&cvms.ClientStreamMessage{Message: runRes})
|
|
}
|
|
|
|
func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.ServerStreamMessage_StopComputation) {
|
|
msg := &cvms.ClientStreamMessage_StopComputationRes{
|
|
StopComputationRes: &cvms.StopComputationResponse{
|
|
ComputationId: mes.StopComputation.ComputationId,
|
|
},
|
|
}
|
|
if err := client.svc.StopComputation(ctx); err != nil {
|
|
msg.StopComputationRes.Message = err.Error()
|
|
}
|
|
|
|
client.mu.Lock()
|
|
if err := client.sp.Stop(); err != nil {
|
|
msg.StopComputationRes.Message = err.Error()
|
|
}
|
|
client.mu.Unlock()
|
|
|
|
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
|
|
}
|
|
|
|
func (client *CVMSClient) sendMessage(mes *cvms.ClientStreamMessage) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
|
|
defer cancel()
|
|
|
|
select {
|
|
case client.messageQueue <- mes:
|
|
case <-ctx.Done():
|
|
client.logger.Warn("Failed to send message: timeout exceeded")
|
|
}
|
|
}
|
|
|
|
type runRequestManager struct {
|
|
requests map[string]*runRequest
|
|
mu sync.Mutex
|
|
}
|
|
|
|
type runRequest struct {
|
|
buffer []byte
|
|
lastChunk time.Time
|
|
timer *time.Timer
|
|
}
|
|
|
|
func newRunRequestManager() *runRequestManager {
|
|
return &runRequestManager{
|
|
requests: make(map[string]*runRequest),
|
|
}
|
|
}
|
|
|
|
func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
req, exists := m.requests[id]
|
|
if !exists {
|
|
req = &runRequest{
|
|
buffer: make([]byte, 0),
|
|
lastChunk: time.Now(),
|
|
timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }),
|
|
}
|
|
m.requests[id] = req
|
|
}
|
|
|
|
req.buffer = append(req.buffer, chunk...)
|
|
req.lastChunk = time.Now()
|
|
req.timer.Reset(runReqTimeout)
|
|
|
|
if isLast {
|
|
delete(m.requests, id)
|
|
req.timer.Stop()
|
|
return req.buffer, true
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
func (m *runRequestManager) timeoutRequest(id string) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
delete(m.requests, id)
|
|
// Log timeout or handle it as needed
|
|
}
|