COCOS-160: Enable mTLS when using aTLS (#434)

* added maTLS feature to agent and cli

* added maTLS feature to agent and cli

* added tests and fixed one bug

* fixed according to comment

* fixed test

* fixed
This commit is contained in:
Jovan Djukic
2025-05-26 21:54:15 +02:00
committed by GitHub
parent 90807d9576
commit bda3968fdf
7 changed files with 124 additions and 64 deletions
+1 -1
View File
@@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Clien
return nil, nil, err
}
if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
if client.Secure() != grpc.WithMATLS && client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+18 -6
View File
@@ -22,10 +22,11 @@ import (
"google.golang.org/grpc/credentials"
)
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, security, error) {
security := withaTLS
err := attestation.ReadAttestationPolicy(cfg.AttestationPolicy, &attestation.AttestationPolicy)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
}
var insecureSkipVerify bool = true
@@ -37,23 +38,25 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
// Read the certificate file
certPEM, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
}
// Decode the PEM block
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block")
return nil, withoutTLS, fmt.Errorf("failed to decode PEM block")
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
}
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
security = withmaTLS
}
tlsConfig := &tls.Config{
@@ -63,7 +66,16 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
return verifyPeerCertificateATLS(rawCerts, verifiedChains, cfg)
},
}
return credentials.NewTLS(tlsConfig), nil
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, withoutTLS, errors.Wrap(errFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return credentials.NewTLS(tlsConfig), security, nil
}
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
+6 -1
View File
@@ -188,7 +188,12 @@ func TestClientSecure(t *testing.T) {
{
name: "With aTLS",
secure: withaTLS,
expected: WithATLS,
expected: "with aTLS",
},
{
name: "With maTLS",
secure: withmaTLS,
expected: WithMATLS,
},
}
+11 -5
View File
@@ -24,11 +24,14 @@ const (
withTLS
withmTLS
withaTLS
withmaTLS
)
const (
WithATLS = "with aTLS"
WithTLS = "with TLS"
AttestationReportSize = 0x4A0
WithMATLS = "with maTLS"
WithATLS = "with aTLS"
WithTLS = "with TLS"
)
var (
@@ -120,7 +123,9 @@ func (c *client) Secure() string {
case withmTLS:
return "with mTLS"
case withaTLS:
return WithATLS
return "with aTLS"
case withmaTLS:
return WithMATLS
default:
return "without TLS"
}
@@ -137,14 +142,15 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
secure := withoutTLS
if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS {
tc, err := setupATLS(agcfg)
tc, sec, err := setupATLS(agcfg)
if err != nil {
return nil, secure, err
}
opts = append(opts, grpc.WithTransportCredentials(tc))
opts = append(opts, grpc.WithContextDialer(CustomDialer))
secure = withaTLS
secure = sec
} else {
conf := cfg.GetBaseConfig()
transportCreds, sec, err := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)