mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-474 - New aTLS implementation (#475)
* initial new aTLS * add CA API call for aTLS
This commit is contained in:
committed by
GitHub
parent
9c8ddfd2b1
commit
698bd948ed
@@ -4,37 +4,20 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
certs "github.com/absmach/certs"
|
||||
certscli "github.com/absmach/certs/cli"
|
||||
"github.com/absmach/certs/errors"
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -67,10 +50,6 @@ type Server struct {
|
||||
cvmId string
|
||||
}
|
||||
|
||||
type csrReq struct {
|
||||
CSR string `json:"csr,omitempty"`
|
||||
}
|
||||
|
||||
type serviceRegister func(srv *grpc.Server)
|
||||
|
||||
var _ server.Server = (*Server)(nil)
|
||||
@@ -107,23 +86,12 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
creds := grpc.Creds(insecure.NewCredentials())
|
||||
var listener net.Listener
|
||||
|
||||
c := s.Config.GetBaseConfig()
|
||||
if agCfg, ok := s.Config.(server.AgentConfig); ok && agCfg.AttestedTLS {
|
||||
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.caUrl, s.cvmId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("falied due to invalid key pair: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: atls.GetCertificate(s.caUrl, s.cvmId),
|
||||
}
|
||||
|
||||
var mtls bool
|
||||
@@ -163,15 +131,7 @@ func (s *Server) Start() error {
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
|
||||
listener, err = atls.Listen(
|
||||
s.Address,
|
||||
certificateBytes,
|
||||
privateKeyBytes,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Listener for aTLS: %w", err)
|
||||
} else if mtls {
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
@@ -227,22 +187,16 @@ func (s *Server) Start() error {
|
||||
default:
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile))
|
||||
}
|
||||
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
default:
|
||||
var err error
|
||||
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
s.server = grpc.NewServer(grpcServerOptions...)
|
||||
@@ -313,171 +267,3 @@ func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
|
||||
func generateCertificatesForATLS(caUrl string, cvmId string) ([]byte, []byte, error) {
|
||||
curve := elliptic.P256()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate private/public key: %w", err)
|
||||
}
|
||||
|
||||
var certDERBytes []byte
|
||||
|
||||
if caUrl == "" || cvmId == "" {
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(202403311),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
DERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
certDERBytes = DERBytes
|
||||
} else {
|
||||
csrmd := certs.CSRMetadata{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
}
|
||||
|
||||
csr, err := certscli.CreateCSR(csrmd, privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create CSR: %w", err)
|
||||
}
|
||||
|
||||
csrData := string(csr.CSR)
|
||||
|
||||
r := csrReq{
|
||||
CSR: csrData,
|
||||
}
|
||||
|
||||
data, error := json.Marshal(r)
|
||||
if error != nil {
|
||||
return nil, nil, errors.NewSDKError(error)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay)
|
||||
ttlString := notAfter.Sub(notBefore).String()
|
||||
|
||||
query := url.Values{}
|
||||
query.Add("ttl", ttlString)
|
||||
query_string := query.Encode()
|
||||
|
||||
certsEndpoint := "certs"
|
||||
csrEndpoint := "csrs"
|
||||
endpoint := fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, cvmId)
|
||||
|
||||
url := fmt.Sprintf("%s/%s?%s", caUrl, endpoint, query_string)
|
||||
|
||||
_, body, sdkerr := processRequest(http.MethodPost, url, data, nil, http.StatusOK)
|
||||
if sdkerr != nil {
|
||||
return nil, nil, errors.NewSDKError(sdkerr)
|
||||
}
|
||||
|
||||
var cert certssdk.Certificate
|
||||
if err := json.Unmarshal(body, &cert); err != nil {
|
||||
return nil, nil, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
|
||||
|
||||
block, rest := pem.Decode([]byte(cleanCertificateString))
|
||||
|
||||
if len(rest) != 0 {
|
||||
return nil, nil, fmt.Errorf("failed to convert generated certificate to DER format: %s", cleanCertificateString)
|
||||
}
|
||||
|
||||
certDERBytes = block.Bytes
|
||||
}
|
||||
|
||||
certBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDERBytes,
|
||||
})
|
||||
|
||||
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal the private key: %w", err)
|
||||
}
|
||||
|
||||
keyBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
|
||||
cert, err := x509.ParseCertificate(certDERBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
ccPlatform := attestation.CCPlatform()
|
||||
if ccPlatform != attestation.TDX {
|
||||
if err := vtpm.ExtendPCR(vtpm.PCR15, pubKeyDER); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extend vTPM PCR with public key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return certBytes, keyBytes, nil
|
||||
}
|
||||
|
||||
func processRequest(method, reqUrl string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
|
||||
req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
// Sets a default value for the Content-Type.
|
||||
// Overridden if Content-Type is passed in the headers arguments.
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
sdkerr := errors.CheckError(resp, expectedRespCodes...)
|
||||
if sdkerr != nil {
|
||||
return make(http.Header), []byte{}, sdkerr
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
return resp.Header, body, nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ type ServerConfiguration interface {
|
||||
type BaseConfig struct {
|
||||
Host string `env:"HOST" envDefault:"localhost"`
|
||||
Port string `env:"PORT" envDefault:"7001"`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
CertFile string `env:"SERVER_CERT" envDefault:""`
|
||||
KeyFile string `env:"SERVER_KEY" envDefault:""`
|
||||
ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""`
|
||||
|
||||
Reference in New Issue
Block a user