Files
cocos/internal/server/grpc/grpc_test.go
T
Danko Miladinovic 67f939fc66
CI / checkproto (push) Has been cancelled
CI / ci (push) Has been cancelled
Rust CI Pipeline / rust-check (push) Has been cancelled
COCOS-326 - Add vTPM support to CoCoS (#376)
* manager, cli and agent vtpm support

* rebase and changed atls for vtpm

* deleted unused code

* changed checkproto.yaml script so it finds the manager proto file correctly

* fix manager proto version

* fix agent tests

* fix server agent test

* fix attestation test

* fix attestation test gofumpt

* created dummy RWC for TPM

* fix comment

* add default PCR values

* rebase main

* fix rust ci and missing header

* changed embedded attestation to VMPL 2

* fix unused import

* fix pkg test

* address attestation type

* fix agent attestation test

* add PCR15 check

* fix comments

* fix cli tests

* add doc

* add mock for LeveledQuoteProvider when SEV-SNP device is not found

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix manager reading attestation policy

* refactor PCR value checks and update attestation policy values

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests for sev and grpc

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
Co-authored-by: Sammy Oina <sammyoina@gmail.com>
2025-03-07 16:36:47 +01:00

560 lines
14 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log/slog"
"math/big"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
"github.com/ultravioletrs/cocos/internal/server"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
// bufSize is the capacity, in bytes (1 MiB), of the in-memory bufconn listener.
const bufSize = 1 << 20

// lis is a package-level in-memory gRPC listener, assigned during package init.
var lis *bufconn.Listener
// DummyRWC is a stateless read/write/close stand-in used in place of a real
// TPM device (it is assigned to vtpm.ExternalTPM in
// TestServerInitializationAndStartup). None of its methods ever fail.
type DummyRWC struct{}

// Read fills every byte of p with byte(len(p)), i.e. the length truncated to
// 8 bits, and reports len(p) bytes read with a nil error.
func (d *DummyRWC) Read(p []byte) (int, error) {
	fill := byte(len(p))
	for idx := 0; idx < len(p); idx++ {
		p[idx] = fill
	}
	return len(p), nil
}

// Write discards p and reports that all len(p) bytes were written.
func (d *DummyRWC) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close is a no-op that always succeeds.
func (d *DummyRWC) Close() error {
	return nil
}
// init creates the shared in-memory listener, sized by bufSize, before any
// test in the package runs.
func init() {
	l := bufconn.Listen(bufSize)
	lis = l
}
// TestNew checks that New, given a minimal host/port configuration and mocked
// quote-provider and authenticator dependencies, returns a non-nil *Server.
func TestNew(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var cfg server.AgentConfig
	cfg.ServerConfig.BaseConfig.Host = "localhost"
	cfg.ServerConfig.BaseConfig.Port = "50051"

	srv := New(
		ctx,
		cancel,
		"TestServer",
		cfg,
		func(*grpc.Server) {},
		slog.Default(),
		new(mocks.LeveledQuoteProvider),
		new(authmocks.Authenticator),
	)

	assert.NotNil(t, srv)
	assert.IsType(t, &Server{}, srv)
}
// TestServerStartWithTLSFile starts the server with a self-signed certificate
// and key supplied as file paths, and checks that it logs a TLS listener.
func TestServerStartWithTLSFile(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cert, key, err := generateSelfSignedCert()
	if !assert.NoError(t, err) {
		return
	}

	// Write the PEM material to temp files; bail out early on failure so a
	// missing file never reaches the server config or the cleanup.
	certPath, err := createTempFile(cert)
	if !assert.NoError(t, err) {
		return
	}
	t.Cleanup(func() { os.Remove(certPath) })

	keyPath, err := createTempFile(key)
	if !assert.NoError(t, err) {
		return
	}
	t.Cleanup(func() { os.Remove(keyPath) })

	config := server.AgentConfig{
		ServerConfig: server.ServerConfig{
			BaseConfig: server.BaseConfig{
				Host:     "localhost",
				Port:     "0",
				CertFile: certPath,
				KeyFile:  keyPath,
			},
		},
	}

	logBuffer := &ThreadSafeBuffer{}
	logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
	qp := new(mocks.LeveledQuoteProvider)
	authSvc := new(authmocks.Authenticator)
	srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)

	// Start may block while serving, so run it in the background and hand its
	// result back over a buffered channel. Asserting here (rather than inside
	// the goroutine) avoids reporting on t after the test has returned.
	errCh := make(chan error, 1)
	go func() { errCh <- srv.Start() }()

	time.Sleep(200 * time.Millisecond)
	cancel()

	select {
	case err := <-errCh:
		assert.NoError(t, err)
	case <-time.After(200 * time.Millisecond):
		// Start has not returned yet; only the log output is checked below.
	}

	logContent := logBuffer.String()
	fmt.Println(logContent)
	assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
// TestServerStartWithmTLSFile starts the server with a CA-signed certificate,
// its key and the CA certificate supplied as file paths, and checks that it
// logs a TLS listener.
func TestServerStartWithmTLSFile(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles()
	if !assert.NoError(t, err) {
		return
	}
	// createCertificatesFiles leaves its files on disk; remove them when done.
	t.Cleanup(func() {
		os.Remove(caCertFile)
		os.Remove(clientCertFile)
		os.Remove(clientKeyFile)
	})

	config := server.AgentConfig{
		ServerConfig: server.ServerConfig{
			BaseConfig: server.BaseConfig{
				Host:         "localhost",
				Port:         "0",
				CertFile:     clientCertFile,
				KeyFile:      clientKeyFile,
				ServerCAFile: caCertFile,
			},
		},
	}

	logBuffer := &ThreadSafeBuffer{}
	logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
	qp := new(mocks.LeveledQuoteProvider)
	authSvc := new(authmocks.Authenticator)
	srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)

	// Collect Start's result on a buffered channel so assertions run on the
	// test goroutine, never after the test has returned.
	errCh := make(chan error, 1)
	go func() { errCh <- srv.Start() }()

	time.Sleep(200 * time.Millisecond)
	cancel()

	select {
	case err := <-errCh:
		assert.NoError(t, err)
	case <-time.After(200 * time.Millisecond):
		// Start has not returned yet; only the log output is checked below.
	}

	logContent := logBuffer.String()
	fmt.Println(logContent)
	assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
// TestServerStop starts a plain (no certificate) server, cancels its context,
// then calls Stop explicitly and expects it to succeed and log the shutdown.
func TestServerStop(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	var cfg server.AgentConfig
	cfg.ServerConfig.BaseConfig.Host = "localhost"
	cfg.ServerConfig.BaseConfig.Port = "0"

	logs := &ThreadSafeBuffer{}
	handler := slog.NewTextHandler(logs, &slog.HandlerOptions{Level: slog.LevelDebug})

	srv := New(
		ctx,
		cancel,
		"TestServer",
		cfg,
		func(*grpc.Server) {},
		slog.New(handler),
		new(mocks.LeveledQuoteProvider),
		new(authmocks.Authenticator),
	)

	go func() {
		assert.NoError(t, srv.Start())
	}()
	time.Sleep(100 * time.Millisecond)

	cancel()
	time.Sleep(100 * time.Millisecond)

	assert.NoError(t, srv.Stop())
	assert.Contains(t, logs.String(), "TestServer gRPC service shutdown at localhost:0")
}
// generateSelfSignedCert creates a fresh 2048-bit RSA key and a self-signed
// certificate for it. It returns the PEM-encoded certificate first, then the
// PEM-encoded key in PKCS #1 ("RSA PRIVATE KEY") form.
func generateSelfSignedCert() ([]byte, []byte, error) {
	privKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	certPEM, err := generateSelfSignedCertFromKey(privKey)
	if err != nil {
		return nil, nil, err
	}

	keyBlock := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privKey),
	}
	return certPEM, pem.EncodeToMemory(keyBlock), nil
}
// generateSelfSignedCertFromKey issues a one-year, server-auth certificate
// (serial 1, organisation "Test") that is signed by key itself, and returns it
// PEM-encoded as a "CERTIFICATE" block.
func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) {
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{Organization: []string{"Test"}},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(1, 0, 0),
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Self-signed: the template is its own parent and the key signs itself.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return nil, err
	}
	return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), nil
}
// ThreadSafeBuffer is an in-memory io.Writer whose accumulated contents can be
// read back as a string. A mutex serialises Write and String, so one instance
// can be shared between a logger and the test goroutine.
type ThreadSafeBuffer struct {
	mu     sync.Mutex
	buffer strings.Builder
}

// Write appends p to the buffer while holding the lock.
func (tb *ThreadSafeBuffer) Write(p []byte) (int, error) {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	n, err := tb.buffer.Write(p)
	return n, err
}

// String returns everything written so far, read under the lock.
func (tb *ThreadSafeBuffer) String() string {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	s := tb.buffer.String()
	return s
}
// TestServerInitializationAndStartup starts the server under a set of
// transport configurations (plain, TLS, mTLS, attested TLS, and invalid
// certificate/CA inputs). Each case checks either the listener log line or
// the error returned by Start.
func TestServerInitializationAndStartup(t *testing.T) {
	// Point the vTPM package at the in-memory DummyRWC (presumably needed by
	// the Attested TLS case). Restore the previous value afterwards so this
	// global override does not leak into other tests.
	prevTPM := vtpm.ExternalTPM
	vtpm.ExternalTPM = &DummyRWC{}
	t.Cleanup(func() { vtpm.ExternalTPM = prevTPM })

	testCases := []struct {
		name          string
		config        server.AgentConfig
		expectedLog   string
		expectError   bool
		setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer)
	}{
		{
			name: "Non-TLS Server Startup",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host: "localhost",
						Port: "0",
					},
				},
			},
			expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS",
		},
		{
			name: "TLS Server Startup with Self-Signed Certificate",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host: "localhost",
						Port: "0",
					},
				},
			},
			setupCallback: setupTLSConfig,
			expectedLog:   "TestServer service gRPC server listening at localhost:0 with TLS",
		},
		{
			name: "TLS Server Startup with Invalid Certificates",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host:     "localhost",
						Port:     "0",
						CertFile: "invalid",
						KeyFile:  "invalid",
					},
				},
			},
			expectError: true,
			expectedLog: "failed to load auth certificates",
		},
		{
			name: "mTLS Server Startup",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host: "localhost",
						Port: "0",
					},
				},
			},
			setupCallback: setupMTLSConfig,
			expectedLog:   "TestServer service gRPC server listening at localhost:0 with TLS",
		},
		{
			name: "mTLS Server Startup with Invalid Root CA",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host:         "localhost",
						Port:         "0",
						ServerCAFile: "invalid",
					},
				},
			},
			setupCallback: setupInvalidRootCAConfig,
			expectError:   true,
			expectedLog:   "failed to append root ca to tls.Config",
		},
		{
			name: "mTLS Server Startup with Invalid Client CA",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host:         "localhost",
						Port:         "0",
						ServerCAFile: "invalid",
					},
				},
			},
			setupCallback: setupInvalidClientCAConfig,
			expectError:   true,
			expectedLog:   "failed to append client ca to tls.Config",
		},
		{
			name: "Attested TLS Server Startup",
			config: server.AgentConfig{
				ServerConfig: server.ServerConfig{
					BaseConfig: server.BaseConfig{
						Host: "localhost",
						Port: "0",
					},
				},
				AttestedTLS: true,
			},
			expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			if tc.setupCallback != nil {
				tc.setupCallback(t, &tc.config, nil)
			}

			logBuffer := &ThreadSafeBuffer{}
			logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
			qp := new(mocks.LeveledQuoteProvider)
			authSvc := new(authmocks.Authenticator)
			srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc)

			// Collect Start's result on a buffered channel so every assertion
			// runs on the subtest goroutine, never after the subtest returned.
			errCh := make(chan error, 1)
			go func() { errCh <- srv.Start() }()

			time.Sleep(200 * time.Millisecond)
			cancel()

			select {
			case err := <-errCh:
				if tc.expectError {
					// Only inspect the message when there is an error, so a
					// nil error fails the test instead of panicking.
					if assert.Error(t, err) {
						assert.Contains(t, err.Error(), tc.expectedLog)
					}
				} else {
					assert.NoError(t, err)
				}
			case <-time.After(200 * time.Millisecond):
				if tc.expectError {
					t.Errorf("expected Start to fail with %q, but it was still running", tc.expectedLog)
				}
			}

			if !tc.expectError {
				logContent := logBuffer.String()
				fmt.Println(logContent)
				assert.Contains(t, logContent, tc.expectedLog)
			}
		})
	}
}
// setupTLSConfig fills config with a freshly generated self-signed certificate
// and key. It stores the PEM contents themselves, not file paths, in CertFile
// and KeyFile. The server must accept inline PEM there, since the TLS case
// using this callback expects a TLS listener. The *ThreadSafeBuffer parameter
// is unused; it exists only to match the setupCallback signature.
func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
	t.Helper() // report assertion failures at the calling test case

	cert, key, err := generateSelfSignedCert()
	assert.NoError(t, err)

	config.CertFile = string(cert)
	config.KeyFile = string(key)
}
// setupMTLSConfig configures mutual TLS with a single self-signed certificate.
// The same certificate PEM (inline, not a path) is the server certificate, the
// server CA and the client CA, and its key PEM is stored in KeyFile. The
// *ThreadSafeBuffer parameter is unused; it only matches the setupCallback
// signature.
func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
	t.Helper() // report assertion failures at the calling test case

	cert, key, err := generateSelfSignedCert()
	assert.NoError(t, err)

	config.CertFile = string(cert)
	config.KeyFile = string(key)
	config.ServerCAFile = string(cert)
	config.ClientCAFile = string(cert)
}
// setupInvalidRootCAConfig configures mTLS with a valid self-signed certificate
// and key and a valid client CA, but sets ServerCAFile to the non-PEM string
// "invalid" so that loading the root CA fails. The *ThreadSafeBuffer parameter
// is unused; it only matches the setupCallback signature.
func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
	t.Helper() // report assertion failures at the calling test case

	cert, key, err := generateSelfSignedCert()
	assert.NoError(t, err)

	config.CertFile = string(cert)
	config.KeyFile = string(key)
	config.ServerCAFile = "invalid"
	config.ClientCAFile = string(cert)
}
// setupInvalidClientCAConfig configures mTLS with a valid self-signed
// certificate and key and a valid server CA, but sets ClientCAFile to the
// non-PEM string "invalid" so that loading the client CA fails. The
// *ThreadSafeBuffer parameter is unused; it only matches the setupCallback
// signature.
func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
	t.Helper() // report assertion failures at the calling test case

	cert, key, err := generateSelfSignedCert()
	assert.NoError(t, err)

	config.CertFile = string(cert)
	config.KeyFile = string(key)
	config.ClientCAFile = "invalid"
	config.ServerCAFile = string(cert)
}
// createCertificatesFiles generates a throw-away CA and a client certificate
// signed by that CA. It writes the CA certificate, the client certificate and
// the client key (PKCS #1) as PEM into separate temporary files, and returns
// their paths in that order.
//
// The caller owns the returned files and must remove them. If any step fails,
// the files created so far are removed and empty paths are returned with the
// error.
func createCertificatesFiles() (string, string, string, error) {
	var created []string
	// fail removes every file written so far and returns the zero result.
	fail := func(err error) (string, string, string, error) {
		for _, name := range created {
			_ = os.Remove(name) // best-effort cleanup; err is what the caller needs
		}
		return "", "", "", err
	}

	caKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fail(err)
	}
	caTemplate := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test Org"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour * 24),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
	if err != nil {
		return fail(err)
	}
	caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}))
	if err != nil {
		return fail(err)
	}
	created = append(created, caCertFile)

	clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fail(err)
	}
	clientTemplate := x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject: pkix.Name{
			Organization: []string{"Test Org"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour * 24),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
	}
	// The client certificate names the CA template as issuer and is signed by the CA key.
	clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey)
	if err != nil {
		return fail(err)
	}
	clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}))
	if err != nil {
		return fail(err)
	}
	created = append(created, clientCertFile)

	clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)}))
	if err != nil {
		return fail(err)
	}

	return caCertFile, clientCertFile, clientKeyFile, nil
}

// createTempFile writes data to a new file in the OS temp directory and
// returns its path; the caller owns the file. If writing or closing fails, the
// file is closed and removed before the error is returned, so nothing leaks.
func createTempFile(data []byte) (string, error) {
	file, err := createTempFileHandle()
	if err != nil {
		return "", err
	}
	if _, err := file.Write(data); err != nil {
		_ = file.Close()
		_ = os.Remove(file.Name())
		return "", err
	}
	if err := file.Close(); err != nil {
		_ = os.Remove(file.Name())
		return "", err
	}
	return file.Name(), nil
}

// createTempFileHandle opens a new, empty temporary file named with the
// "test" prefix in the OS default temp directory.
func createTempFileHandle() (*os.File, error) {
	return os.CreateTemp("", "test")
}