COCOS-160: Enable mTLS when using aTLS (#434)

* added maTLS feature to agent and cli

* added maTLS feature to agent and cli

* added tests and fixed one bug

* fixed according to comment

* fixed test

* fixed
This commit is contained in:
Jovan Djukic
2025-05-26 21:54:15 +02:00
committed by GitHub
parent 90807d9576
commit bda3968fdf
7 changed files with 124 additions and 64 deletions
+43 -2
View File
@@ -108,6 +108,7 @@ func (s *Server) Start() error {
creds := grpc.Creds(insecure.NewCredentials())
var listener net.Listener
c := s.Config.GetBaseConfig()
if agCfg, ok := s.Config.(server.AgentConfig); ok && agCfg.AttestedTLS {
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.caUrl, s.cvmId)
if err != nil {
@@ -124,6 +125,41 @@ func (s *Server) Start() error {
Certificates: []tls.Certificate{certificate},
}
var mtls bool
mtls = false
// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return fmt.Errorf("failed to load server ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append server ca to tls.Config")
}
mtls = true
}
// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
mtls = true
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
listener, err = atls.Listen(
@@ -131,12 +167,15 @@ func (s *Server) Start() error {
certificateBytes,
privateKeyBytes,
)
if err != nil {
return fmt.Errorf("failed to create Listener for aTLS: %w", err)
} else if mtls {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
}
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
} else {
c := s.Config.GetBaseConfig()
switch {
case c.CertFile != "" || c.KeyFile != "":
certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile)
@@ -253,6 +292,8 @@ func readFileOrData(input string) ([]byte, error) {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
} else {
return nil, err
}
}
return []byte(input), nil
+38 -45
View File
@@ -327,60 +327,53 @@ func TestServerInitializationAndStartup(t *testing.T) {
expectedLog: "failed to load auth certificates",
},
{
name: "mTLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
},
},
},
setupCallback: setupMTLSConfig,
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
},
{
name: "mTLS Server Startup with Invalid Root CA",
name: "maTLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
},
setupCallback: setupInvalidRootCAConfig,
expectError: true,
expectedLog: "failed to append root ca to tls.Config",
},
{
name: "mTLS Server Startup with Invalid Client CA",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
},
setupCallback: setupInvalidClientCAConfig,
expectError: true,
expectedLog: "failed to append client ca to tls.Config",
},
{
name: "Attested TLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
ServerCAFile: "",
ClientCAFile: "",
},
},
AttestedTLS: true,
},
expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS",
setupCallback: setupMTLSConfig,
expectError: false,
expectedLog: "with Attested mTLS",
},
{
name: "maTLS Server Startup with Invalid Server CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
AttestedTLS: true,
},
setupCallback: setupInvalidRootCAConfig,
expectError: true,
expectedLog: "failed to load server ca file",
},
{
name: "maTLS Server Startup with Invalid Clinet CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
AttestedTLS: true,
},
setupCallback: setupInvalidClientCAConfig,
expectError: true,
expectedLog: "failed to load client ca file",
},
}
+1 -1
View File
@@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Clien
return nil, nil, err
}
if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
if client.Secure() != grpc.WithMATLS && client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+18 -6
View File
@@ -22,10 +22,11 @@ import (
"google.golang.org/grpc/credentials"
)
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, security, error) {
security := withaTLS
err := attestation.ReadAttestationPolicy(cfg.AttestationPolicy, &attestation.AttestationPolicy)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
}
var insecureSkipVerify bool = true
@@ -37,23 +38,25 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
// Read the certificate file
certPEM, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
}
// Decode the PEM block
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block")
return nil, withoutTLS, fmt.Errorf("failed to decode PEM block")
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
}
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
security = withmaTLS
}
tlsConfig := &tls.Config{
@@ -63,7 +66,16 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
return verifyPeerCertificateATLS(rawCerts, verifiedChains, cfg)
},
}
return credentials.NewTLS(tlsConfig), nil
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, withoutTLS, errors.Wrap(errFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return credentials.NewTLS(tlsConfig), security, nil
}
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
+6 -1
View File
@@ -188,7 +188,12 @@ func TestClientSecure(t *testing.T) {
{
name: "With aTLS",
secure: withaTLS,
expected: WithATLS,
expected: "with aTLS",
},
{
name: "With maTLS",
secure: withmaTLS,
expected: WithMATLS,
},
}
+11 -5
View File
@@ -24,11 +24,14 @@ const (
withTLS
withmTLS
withaTLS
withmaTLS
)
const (
WithATLS = "with aTLS"
WithTLS = "with TLS"
AttestationReportSize = 0x4A0
WithMATLS = "with maTLS"
WithATLS = "with aTLS"
WithTLS = "with TLS"
)
var (
@@ -120,7 +123,9 @@ func (c *client) Secure() string {
case withmTLS:
return "with mTLS"
case withaTLS:
return WithATLS
return "with aTLS"
case withmaTLS:
return WithMATLS
default:
return "without TLS"
}
@@ -137,14 +142,15 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
secure := withoutTLS
if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS {
tc, err := setupATLS(agcfg)
tc, sec, err := setupATLS(agcfg)
if err != nil {
return nil, secure, err
}
opts = append(opts, grpc.WithTransportCredentials(tc))
opts = append(opts, grpc.WithContextDialer(CustomDialer))
secure = withaTLS
secure = sec
} else {
conf := cfg.GetBaseConfig()
transportCreds, sec, err := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
+7 -4
View File
@@ -42,6 +42,7 @@ var (
pubKeyFile string
caUrl string
cvmId string
clientCAFile string
)
type svc struct {
@@ -89,8 +90,9 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
Algorithm: &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
ResultConsumers: []*cvms.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: &cvms.AgentConfig{
Port: "7002",
AttestedTls: attestedTLS,
Port: "7002",
AttestedTls: attestedTLS,
ClientCaFile: clientCAFile,
},
},
},
@@ -106,8 +108,9 @@ func main() {
flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas")
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, optional flag that can only be used if aTLS is enabled")
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, optional flag that can only be used if aTLS is enabled")
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, must be specified if aTLS is used")
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, must be specified if aTLS is used")
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
flagSetParseError := flagSet.Parse(os.Args[1:])
if flagSetParseError != nil {