mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-160: Enable mTLS when using aTLS (#434)
* added maTLS feature to agent and cli * added maTLS feature to agent and cli * added tests and fixed one bug * fixed according to comment * fixed test * fixed
This commit is contained in:
@@ -108,6 +108,7 @@ func (s *Server) Start() error {
|
||||
creds := grpc.Creds(insecure.NewCredentials())
|
||||
var listener net.Listener
|
||||
|
||||
c := s.Config.GetBaseConfig()
|
||||
if agCfg, ok := s.Config.(server.AgentConfig); ok && agCfg.AttestedTLS {
|
||||
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.caUrl, s.cvmId)
|
||||
if err != nil {
|
||||
@@ -124,6 +125,41 @@ func (s *Server) Start() error {
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
}
|
||||
|
||||
var mtls bool
|
||||
mtls = false
|
||||
|
||||
// Loading Server CA file
|
||||
rootCA, err := loadCertFile(c.ServerCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load server ca file: %w", err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
|
||||
return fmt.Errorf("failed to append server ca to tls.Config")
|
||||
}
|
||||
mtls = true
|
||||
}
|
||||
|
||||
// Loading Client CA File
|
||||
clientCA, err := loadCertFile(c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load client ca file: %w", err)
|
||||
}
|
||||
if len(clientCA) > 0 {
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
tlsConfig.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
|
||||
return fmt.Errorf("failed to append client ca to tls.Config")
|
||||
}
|
||||
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
mtls = true
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
|
||||
listener, err = atls.Listen(
|
||||
@@ -131,12 +167,15 @@ func (s *Server) Start() error {
|
||||
certificateBytes,
|
||||
privateKeyBytes,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Listener for aTLS: %w", err)
|
||||
} else if mtls {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
} else {
|
||||
c := s.Config.GetBaseConfig()
|
||||
switch {
|
||||
case c.CertFile != "" || c.KeyFile != "":
|
||||
certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile)
|
||||
@@ -253,6 +292,8 @@ func readFileOrData(input string) ([]byte, error) {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return []byte(input), nil
|
||||
|
||||
@@ -327,60 +327,53 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
expectedLog: "failed to load auth certificates",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
},
|
||||
},
|
||||
setupCallback: setupMTLSConfig,
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup with Invalid Root CA",
|
||||
name: "maTLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
setupCallback: setupInvalidRootCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to append root ca to tls.Config",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup with Invalid Client CA",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
setupCallback: setupInvalidClientCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to append client ca to tls.Config",
|
||||
},
|
||||
{
|
||||
name: "Attested TLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "",
|
||||
ClientCAFile: "",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS",
|
||||
setupCallback: setupMTLSConfig,
|
||||
expectError: false,
|
||||
expectedLog: "with Attested mTLS",
|
||||
},
|
||||
{
|
||||
name: "maTLS Server Startup with Invalid Server CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
setupCallback: setupInvalidRootCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to load server ca file",
|
||||
},
|
||||
{
|
||||
name: "maTLS Server Startup with Invalid Clinet CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
setupCallback: setupInvalidClientCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to load client ca file",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Clien
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
|
||||
if client.Secure() != grpc.WithMATLS && client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
|
||||
@@ -22,10 +22,11 @@ import (
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error) {
|
||||
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, security, error) {
|
||||
security := withaTLS
|
||||
err := attestation.ReadAttestationPolicy(cfg.AttestationPolicy, &attestation.AttestationPolicy)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err)
|
||||
}
|
||||
|
||||
var insecureSkipVerify bool = true
|
||||
@@ -37,23 +38,25 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
|
||||
// Read the certificate file
|
||||
certPEM, err := os.ReadFile(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
|
||||
}
|
||||
|
||||
// Decode the PEM block
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode PEM block")
|
||||
return nil, withoutTLS, fmt.Errorf("failed to decode PEM block")
|
||||
}
|
||||
|
||||
// Parse the certificate
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
|
||||
}
|
||||
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
|
||||
security = withmaTLS
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
@@ -63,7 +66,16 @@ func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, error)
|
||||
return verifyPeerCertificateATLS(rawCerts, verifiedChains, cfg)
|
||||
},
|
||||
}
|
||||
return credentials.NewTLS(tlsConfig), nil
|
||||
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(errFailedToLoadClientCertKey, err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
return credentials.NewTLS(tlsConfig), security, nil
|
||||
}
|
||||
|
||||
func CustomDialer(ctx context.Context, addr string) (net.Conn, error) {
|
||||
|
||||
@@ -188,7 +188,12 @@ func TestClientSecure(t *testing.T) {
|
||||
{
|
||||
name: "With aTLS",
|
||||
secure: withaTLS,
|
||||
expected: WithATLS,
|
||||
expected: "with aTLS",
|
||||
},
|
||||
{
|
||||
name: "With maTLS",
|
||||
secure: withmaTLS,
|
||||
expected: WithMATLS,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -24,11 +24,14 @@ const (
|
||||
withTLS
|
||||
withmTLS
|
||||
withaTLS
|
||||
withmaTLS
|
||||
)
|
||||
|
||||
const (
|
||||
WithATLS = "with aTLS"
|
||||
WithTLS = "with TLS"
|
||||
AttestationReportSize = 0x4A0
|
||||
WithMATLS = "with maTLS"
|
||||
WithATLS = "with aTLS"
|
||||
WithTLS = "with TLS"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -120,7 +123,9 @@ func (c *client) Secure() string {
|
||||
case withmTLS:
|
||||
return "with mTLS"
|
||||
case withaTLS:
|
||||
return WithATLS
|
||||
return "with aTLS"
|
||||
case withmaTLS:
|
||||
return WithMATLS
|
||||
default:
|
||||
return "without TLS"
|
||||
}
|
||||
@@ -137,14 +142,15 @@ func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
|
||||
secure := withoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS {
|
||||
tc, err := setupATLS(agcfg)
|
||||
tc, sec, err := setupATLS(agcfg)
|
||||
if err != nil {
|
||||
return nil, secure, err
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(tc))
|
||||
opts = append(opts, grpc.WithContextDialer(CustomDialer))
|
||||
secure = withaTLS
|
||||
|
||||
secure = sec
|
||||
} else {
|
||||
conf := cfg.GetBaseConfig()
|
||||
transportCreds, sec, err := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
|
||||
+7
-4
@@ -42,6 +42,7 @@ var (
|
||||
pubKeyFile string
|
||||
caUrl string
|
||||
cvmId string
|
||||
clientCAFile string
|
||||
)
|
||||
|
||||
type svc struct {
|
||||
@@ -89,8 +90,9 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
Algorithm: &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []*cvms.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: &cvms.AgentConfig{
|
||||
Port: "7002",
|
||||
AttestedTls: attestedTLS,
|
||||
Port: "7002",
|
||||
AttestedTls: attestedTLS,
|
||||
ClientCaFile: clientCAFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -106,8 +108,9 @@ func main() {
|
||||
flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
|
||||
flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
|
||||
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas")
|
||||
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, optional flag that can only be used if aTLS is enabled")
|
||||
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, optional flag that can only be used if aTLS is enabled")
|
||||
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, must be specified if aTLS is used")
|
||||
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, must be specified if aTLS is used")
|
||||
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
|
||||
|
||||
flagSetParseError := flagSet.Parse(os.Args[1:])
|
||||
if flagSetParseError != nil {
|
||||
|
||||
Reference in New Issue
Block a user