mirror of
https://github.com/cloudflare/cloudflared.git
synced 2026-06-23 04:10:20 +00:00
TUN-10390: Fix missing TLS settings
Check / check (1.22.x, macos-latest) (push) Has been cancelled
Check / check (1.22.x, ubuntu-latest) (push) Has been cancelled
Check / check (1.22.x, windows-latest) (push) Has been cancelled
Semgrep config / semgrep/ci (push) Has been cancelled
Check / check (1.22.x, macos-latest) (push) Has been cancelled
Check / check (1.22.x, ubuntu-latest) (push) Has been cancelled
Check / check (1.22.x, windows-latest) (push) Has been cancelled
Semgrep config / semgrep/ci (push) Has been cancelled
Fixing missing TLS settings. While developing the pre-check probes, I forgot to add the certificate settings, which are essential for establishing a connection to origintunneld. I discovered this while testing cloudflared locally.
This commit is contained in:
@@ -81,6 +81,9 @@ const (
|
||||
// EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge
|
||||
EdgeBindAddress = "edge-bind-address"
|
||||
|
||||
// CACert Certificate Authority authenticating connections with Cloudflare's edge network.
|
||||
CACert = "cacert"
|
||||
|
||||
// Force is the command line flag to specify if you wish to force an action
|
||||
Force = "force"
|
||||
|
||||
|
||||
@@ -642,7 +642,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||
Hidden: false,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: tlsconfig.CaCertFlag,
|
||||
Name: cfdflags.CACert,
|
||||
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
|
||||
EnvVars: []string{"TUNNEL_CACERT"},
|
||||
Hidden: true,
|
||||
|
||||
@@ -168,7 +168,7 @@ func prepareTunnelConfig(
|
||||
if tlsSettings == nil {
|
||||
return nil, nil, fmt.Errorf("%s has unknown TLS settings", p)
|
||||
}
|
||||
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
|
||||
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c.String(flags.CACert), tlsSettings.ServerName)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||
}
|
||||
|
||||
+41
-7
@@ -2,6 +2,7 @@ package prechecks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
@@ -54,7 +55,7 @@ func (tr TransportResults) Collect() []CheckResult {
|
||||
//
|
||||
// Each failed probe is retried up to maxRetries times with exponential backoff.
|
||||
// The suite is bounded by cfg.Timeout (defaultTimeout if zero).
|
||||
func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report {
|
||||
func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report {
|
||||
runID := uuid.New()
|
||||
|
||||
if cfg.Timeout <= 0 {
|
||||
@@ -63,6 +64,10 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
|
||||
ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build TLS configs once per protocol
|
||||
quicTLSConfig, quicTLSErr := probeTLSConfig(caCert, connection.QUIC)
|
||||
http2TLSConfig, http2TLSErr := probeTLSConfig(caCert, connection.HTTP2)
|
||||
|
||||
// 1) DNS – must complete before transport probes know which addresses to dial.
|
||||
addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
|
||||
|
||||
@@ -81,8 +86,8 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
|
||||
|
||||
if !dnsOK {
|
||||
// DNS failed: emit one skip row per region so the table stays consistent.
|
||||
results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, "UDP Connectivity")
|
||||
results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, "TCP Connectivity")
|
||||
results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, componentUDPConnectivity)
|
||||
results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity)
|
||||
} else {
|
||||
perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion)
|
||||
regionTargets := dnsTargets(dnsResults)
|
||||
@@ -91,18 +96,30 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
|
||||
http2Ch := make(chan []CheckResult, 1)
|
||||
|
||||
go func() {
|
||||
quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, "UDP Connectivity",
|
||||
if quicTLSErr != nil {
|
||||
log.Warn().Err(quicTLSErr).Msg("Failed to build QUIC probe TLS config")
|
||||
quicCh <- tlsConfigErrResults(ProbeTypeQUIC, componentUDPConnectivity,
|
||||
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked)
|
||||
return
|
||||
}
|
||||
quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, componentUDPConnectivity,
|
||||
perRegionAddrs, regionTargets,
|
||||
func(addr *allregions.EdgeAddr) CheckResult {
|
||||
return probeQUIC(ctx, runDialers.QUICDialer, addr, log)
|
||||
return probeQUIC(ctx, quicTLSConfig, runDialers.QUICDialer, addr, log)
|
||||
})
|
||||
}()
|
||||
|
||||
go func() {
|
||||
http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, "TCP Connectivity",
|
||||
if http2TLSErr != nil {
|
||||
log.Warn().Err(http2TLSErr).Msg("Failed to build HTTP/2 probe TLS config")
|
||||
http2Ch <- tlsConfigErrResults(ProbeTypeHTTP2, componentTCPConnectivity,
|
||||
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked)
|
||||
return
|
||||
}
|
||||
http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, componentTCPConnectivity,
|
||||
perRegionAddrs, regionTargets,
|
||||
func(addr *allregions.EdgeAddr) CheckResult {
|
||||
return probeHTTP2(ctx, runDialers.TCPDialer, addr)
|
||||
return probeHTTP2(ctx, http2TLSConfig, runDialers.TCPDialer, addr)
|
||||
})
|
||||
}()
|
||||
|
||||
@@ -119,6 +136,23 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
|
||||
}
|
||||
}
|
||||
|
||||
// tlsConfigErrResults returns one Fail CheckResult per region target, used when
|
||||
// TLS config construction fails before any dial is attempted.
|
||||
func tlsConfigErrResults(probeType ProbeType, component string, regionTargets []string, details, action string) []CheckResult {
|
||||
results := make([]CheckResult, len(regionTargets))
|
||||
for i, target := range regionTargets {
|
||||
results[i] = CheckResult{
|
||||
Type: probeType,
|
||||
Component: component,
|
||||
Target: target,
|
||||
ProbeStatus: Fail,
|
||||
Details: details,
|
||||
Action: action,
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) {
|
||||
var addrGroups [][]*allregions.EdgeAddr
|
||||
var dnsResults []CheckResult
|
||||
|
||||
+16
-12
@@ -19,6 +19,10 @@ import (
|
||||
"github.com/cloudflare/cloudflared/mocks"
|
||||
)
|
||||
|
||||
const (
|
||||
emptyCert = ""
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -114,7 +118,7 @@ func TestRun_AllPass(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results.
|
||||
@@ -145,7 +149,7 @@ func TestRun_QUICBlocked(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Pass + 1 API Pass.
|
||||
@@ -175,7 +179,7 @@ func TestRun_HTTP2Blocked(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Fail + 1 API Pass.
|
||||
@@ -205,7 +209,7 @@ func TestRun_BothTransportsBlocked(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Fail + 1 API Pass.
|
||||
@@ -244,7 +248,7 @@ func TestRun_PartialRegionQUICFail(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS Pass + QUIC-region1 Pass + QUIC-region2 Fail + 2 HTTP2 Pass + 1 API Pass.
|
||||
@@ -277,7 +281,7 @@ func TestRun_DNSFail_SkipsTransports(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// DNS failure emits 2 Fail rows (one per default region).
|
||||
@@ -314,7 +318,7 @@ func TestRun_ManagementAPIFail(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nil, errors.New("connection refused")).AnyTimes()
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Pass + 1 API Fail.
|
||||
@@ -345,7 +349,7 @@ func TestRun_RegionFlagForwardedToDNS(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Region: "us", Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Region: "us", Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// DNS rows carry regional hostnames (indices 0 and 1).
|
||||
@@ -383,7 +387,7 @@ func TestRun_QUICUsesProbeConnIndex(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
}
|
||||
|
||||
@@ -407,7 +411,7 @@ func TestRun_BothFamiliesProbed(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
// 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results, all passing.
|
||||
@@ -437,7 +441,7 @@ func TestRun_IPv4OnlySkipsV6(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
|
||||
@@ -464,7 +468,7 @@ func TestRun_IPv6OnlySkipsV4(t *testing.T) {
|
||||
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nopConn{}, nil)
|
||||
|
||||
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only},
|
||||
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only},
|
||||
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
|
||||
|
||||
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
|
||||
|
||||
+23
-16
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,6 +51,7 @@ const (
|
||||
detailsConnectionFailed = "Connection failed"
|
||||
detailsTCPPortReachable = "TCP port reachable (TLS not validated)"
|
||||
detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
|
||||
detailsTLSConfigFailed = "TLS configuration failed"
|
||||
|
||||
// Region hostname templates.
|
||||
region1Global = "region1.v2.argotunnel.com"
|
||||
@@ -103,6 +105,25 @@ func (d *NetManagementDialer) DialContext(ctx context.Context, network, addr str
|
||||
return d.Dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
// probeTLSConfig builds a *tls.Config for a pre-check probe using the same
|
||||
// certificate pool as the production tunnel. The SNI and NextProtos are taken from
|
||||
// p.ProbeTLSSettings() so that the probe SNI is used instead of the production SNI,
|
||||
// which avoids noisy logs in origintunneld.
|
||||
func probeTLSConfig(caCert string, p connection.Protocol) (*tls.Config, error) {
|
||||
settings := p.ProbeTLSSettings()
|
||||
if settings == nil {
|
||||
return nil, fmt.Errorf("no probe TLS settings for protocol %s", p)
|
||||
}
|
||||
cfg, err := tlsconfig.CreateTunnelConfig(caCert, settings.ServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(settings.NextProtos) > 0 {
|
||||
cfg.NextProtos = settings.NextProtos
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// probeDNS resolves edge addresses for the given region via the supplied
|
||||
// DNSResolver and returns a CheckResult for each region discovered. If
|
||||
// resolution fails for all regions, every result will carry StatusFail.
|
||||
@@ -153,6 +174,7 @@ func probeDNS(
|
||||
// budget.
|
||||
func probeQUIC(
|
||||
ctx context.Context,
|
||||
tlsConfig *tls.Config,
|
||||
dialer QUICDialer,
|
||||
addr *allregions.EdgeAddr,
|
||||
logger *zerolog.Logger,
|
||||
@@ -160,14 +182,6 @@ func probeQUIC(
|
||||
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
tlsSettings := connection.QUIC.ProbeTLSSettings()
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: tlsSettings.ServerName,
|
||||
NextProtos: tlsSettings.NextProtos,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256},
|
||||
}
|
||||
|
||||
// We call dialer.DialQuic with isProbe = true, which bypasses connIndex check.
|
||||
// Therefore, whatever we add to connIndex will not be relevant.
|
||||
edgeAddrPort := addr.UDP.AddrPort()
|
||||
@@ -213,14 +227,7 @@ func probeQUIC(
|
||||
//
|
||||
// The dial timeout is capped at perProbeDialTimeout so that a single blocked
|
||||
// dial cannot exhaust the entire suite budget.
|
||||
func probeHTTP2(ctx context.Context, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult {
|
||||
tlsSettings := connection.HTTP2.ProbeTLSSettings()
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: tlsSettings.ServerName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256},
|
||||
}
|
||||
|
||||
func probeHTTP2(ctx context.Context, tlsConfig *tls.Config, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult {
|
||||
conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil)
|
||||
if err != nil {
|
||||
return CheckResult{
|
||||
|
||||
+13
-10
@@ -2,6 +2,7 @@ package prechecks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
@@ -24,12 +25,14 @@ const (
|
||||
testRegion2US = region2US
|
||||
testRegion1Fed = region1Fed
|
||||
testRegion2Fed = region2Fed
|
||||
testRegion1EU = "eu-region1.v2.argotunnel.com"
|
||||
testRegion2EU = "eu-region2.v2.argotunnel.com"
|
||||
|
||||
testEdgePort = 7844
|
||||
)
|
||||
|
||||
// testTLSConfig is a minimal *tls.Config for tests. Mock dialers never
|
||||
// perform a real TLS handshake, so an empty config is sufficient.
|
||||
var testTLSConfig = &tls.Config{} //nolint:gosec
|
||||
|
||||
// mockQuicConnection is a minimal test double for quic.Connection.
|
||||
type mockQuicConnection struct {
|
||||
closeErr error
|
||||
@@ -231,7 +234,7 @@ func TestProbeQUIC_Success(t *testing.T) {
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
logger := zerolog.New(nil)
|
||||
|
||||
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
|
||||
|
||||
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||
assert.Equal(t, Pass, result.ProbeStatus)
|
||||
@@ -249,7 +252,7 @@ func TestProbeQUIC_DialError(t *testing.T) {
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
logger := zerolog.New(nil)
|
||||
|
||||
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
|
||||
|
||||
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||
assert.Equal(t, Fail, result.ProbeStatus)
|
||||
@@ -269,7 +272,7 @@ func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
logger := zerolog.New(nil)
|
||||
|
||||
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
|
||||
|
||||
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||
assert.Equal(t, Pass, result.ProbeStatus)
|
||||
@@ -287,7 +290,7 @@ func TestProbeQUIC_ContextTimeout(t *testing.T) {
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
logger := zerolog.New(nil)
|
||||
|
||||
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
|
||||
|
||||
assert.Equal(t, Fail, result.ProbeStatus)
|
||||
assert.Equal(t, detailsHandshakeFailed, result.Details)
|
||||
@@ -305,7 +308,7 @@ func TestProbeHTTP2_Success(t *testing.T) {
|
||||
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
|
||||
result := probeHTTP2(context.Background(), dialer, addr)
|
||||
result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
|
||||
|
||||
assert.Equal(t, ProbeTypeHTTP2, result.Type)
|
||||
assert.Equal(t, Pass, result.ProbeStatus)
|
||||
@@ -322,7 +325,7 @@ func TestProbeHTTP2_DialError(t *testing.T) {
|
||||
|
||||
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||
|
||||
result := probeHTTP2(context.Background(), dialer, addr)
|
||||
result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
|
||||
|
||||
assert.Equal(t, ProbeTypeHTTP2, result.Type)
|
||||
assert.Equal(t, Fail, result.ProbeStatus)
|
||||
@@ -512,7 +515,7 @@ func TestProbeQUIC_IPv6Address(t *testing.T) {
|
||||
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||
logger := zerolog.New(nil)
|
||||
|
||||
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
|
||||
|
||||
assert.Equal(t, Pass, result.ProbeStatus)
|
||||
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
|
||||
@@ -530,7 +533,7 @@ func TestProbeHTTP2_IPv6Address(t *testing.T) {
|
||||
|
||||
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||
|
||||
result := probeHTTP2(context.Background(), dialer, addr)
|
||||
result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
|
||||
|
||||
assert.Equal(t, Pass, result.ProbeStatus)
|
||||
}
|
||||
|
||||
@@ -11,12 +11,10 @@ import (
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
OriginCAPoolFlag = "origin-ca-pool"
|
||||
CaCertFlag = "cacert"
|
||||
)
|
||||
|
||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
||||
@@ -65,7 +63,7 @@ func (cr *CertReloader) LoadCert() error {
|
||||
|
||||
// Keep the old certificate if there's a problem reading the new one.
|
||||
if err != nil {
|
||||
sentry.CaptureException(fmt.Errorf("Error parsing X509 key pair: %v", err))
|
||||
sentry.CaptureException(fmt.Errorf("error parsing X509 key pair: %v", err))
|
||||
return err
|
||||
}
|
||||
cr.certificate = &cert
|
||||
@@ -77,6 +75,7 @@ func LoadOriginCA(originCAPoolFilename string, log *zerolog.Logger) (*x509.CertP
|
||||
|
||||
if originCAPoolFilename != "" {
|
||||
var err error
|
||||
// nolint:gosec
|
||||
originCustomCAPool, err = os.ReadFile(originCAPoolFilename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
|
||||
@@ -116,6 +115,7 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
// nolint: gosec
|
||||
customOriginCA, err := os.ReadFile(originCAFilename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", originCAFilename))
|
||||
@@ -127,10 +127,10 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
|
||||
func CreateTunnelConfig(caCert string, serverName string) (*tls.Config, error) {
|
||||
var rootCAs []string
|
||||
if c.String(CaCertFlag) != "" {
|
||||
rootCAs = append(rootCAs, c.String(CaCertFlag))
|
||||
if caCert != "" {
|
||||
rootCAs = append(rootCAs, caCert)
|
||||
}
|
||||
|
||||
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
|
||||
|
||||
Reference in New Issue
Block a user