TUN-10390: Fix missing TLS settings
Check / check (1.22.x, macos-latest) (push) Has been cancelled
Check / check (1.22.x, ubuntu-latest) (push) Has been cancelled
Check / check (1.22.x, windows-latest) (push) Has been cancelled
Semgrep config / semgrep/ci (push) Has been cancelled

Fixing missing TLS settings. While developing the pre-check probes, I forgot to add the certificate settings, which are essential for establishing a connection to origintunneld. I discovered this while testing cloudflared locally.
This commit is contained in:
Miguel da Costa Martins Marcelino
2026-05-06 11:17:59 +00:00
parent 7585e38948
commit e8f8b2afb7
8 changed files with 104 additions and 53 deletions
+3
View File
@@ -81,6 +81,9 @@ const (
// EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge // EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge
EdgeBindAddress = "edge-bind-address" EdgeBindAddress = "edge-bind-address"
// CACert Certificate Authority authenticating connections with Cloudflare's edge network.
CACert = "cacert"
// Force is the command line flag to specify if you wish to force an action // Force is the command line flag to specify if you wish to force an action
Force = "force" Force = "force"
+1 -1
View File
@@ -642,7 +642,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: false, Hidden: false,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: tlsconfig.CaCertFlag, Name: cfdflags.CACert,
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.", Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
EnvVars: []string{"TUNNEL_CACERT"}, EnvVars: []string{"TUNNEL_CACERT"},
Hidden: true, Hidden: true,
+1 -1
View File
@@ -168,7 +168,7 @@ func prepareTunnelConfig(
if tlsSettings == nil { if tlsSettings == nil {
return nil, nil, fmt.Errorf("%s has unknown TLS settings", p) return nil, nil, fmt.Errorf("%s has unknown TLS settings", p)
} }
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName) edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c.String(flags.CACert), tlsSettings.ServerName)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
+41 -7
View File
@@ -2,6 +2,7 @@ package prechecks
import ( import (
"context" "context"
"fmt"
"slices" "slices"
"time" "time"
@@ -54,7 +55,7 @@ func (tr TransportResults) Collect() []CheckResult {
// //
// Each failed probe is retried up to maxRetries times with exponential backoff. // Each failed probe is retried up to maxRetries times with exponential backoff.
// The suite is bounded by cfg.Timeout (defaultTimeout if zero). // The suite is bounded by cfg.Timeout (defaultTimeout if zero).
func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report { func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report {
runID := uuid.New() runID := uuid.New()
if cfg.Timeout <= 0 { if cfg.Timeout <= 0 {
@@ -63,6 +64,10 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
ctx, cancel := context.WithTimeout(ctx, cfg.Timeout) ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
defer cancel() defer cancel()
// Build TLS configs once per protocol
quicTLSConfig, quicTLSErr := probeTLSConfig(caCert, connection.QUIC)
http2TLSConfig, http2TLSErr := probeTLSConfig(caCert, connection.HTTP2)
// 1) DNS must complete before transport probes know which addresses to dial. // 1) DNS must complete before transport probes know which addresses to dial.
addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region) addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
@@ -81,8 +86,8 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
if !dnsOK { if !dnsOK {
// DNS failed: emit one skip row per region so the table stays consistent. // DNS failed: emit one skip row per region so the table stays consistent.
results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, "UDP Connectivity") results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, componentUDPConnectivity)
results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, "TCP Connectivity") results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity)
} else { } else {
perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion) perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion)
regionTargets := dnsTargets(dnsResults) regionTargets := dnsTargets(dnsResults)
@@ -91,18 +96,30 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
http2Ch := make(chan []CheckResult, 1) http2Ch := make(chan []CheckResult, 1)
go func() { go func() {
quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, "UDP Connectivity", if quicTLSErr != nil {
log.Warn().Err(quicTLSErr).Msg("Failed to build QUIC probe TLS config")
quicCh <- tlsConfigErrResults(ProbeTypeQUIC, componentUDPConnectivity,
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked)
return
}
quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, componentUDPConnectivity,
perRegionAddrs, regionTargets, perRegionAddrs, regionTargets,
func(addr *allregions.EdgeAddr) CheckResult { func(addr *allregions.EdgeAddr) CheckResult {
return probeQUIC(ctx, runDialers.QUICDialer, addr, log) return probeQUIC(ctx, quicTLSConfig, runDialers.QUICDialer, addr, log)
}) })
}() }()
go func() { go func() {
http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, "TCP Connectivity", if http2TLSErr != nil {
log.Warn().Err(http2TLSErr).Msg("Failed to build HTTP/2 probe TLS config")
http2Ch <- tlsConfigErrResults(ProbeTypeHTTP2, componentTCPConnectivity,
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked)
return
}
http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, componentTCPConnectivity,
perRegionAddrs, regionTargets, perRegionAddrs, regionTargets,
func(addr *allregions.EdgeAddr) CheckResult { func(addr *allregions.EdgeAddr) CheckResult {
return probeHTTP2(ctx, runDialers.TCPDialer, addr) return probeHTTP2(ctx, http2TLSConfig, runDialers.TCPDialer, addr)
}) })
}() }()
@@ -119,6 +136,23 @@ func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDia
} }
} }
// tlsConfigErrResults returns one Fail CheckResult per region target, used when
// TLS config construction fails before any dial is attempted.
func tlsConfigErrResults(probeType ProbeType, component string, regionTargets []string, details, action string) []CheckResult {
results := make([]CheckResult, len(regionTargets))
for i, target := range regionTargets {
results[i] = CheckResult{
Type: probeType,
Component: component,
Target: target,
ProbeStatus: Fail,
Details: details,
Action: action,
}
}
return results
}
func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) { func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) {
var addrGroups [][]*allregions.EdgeAddr var addrGroups [][]*allregions.EdgeAddr
var dnsResults []CheckResult var dnsResults []CheckResult
+16 -12
View File
@@ -19,6 +19,10 @@ import (
"github.com/cloudflare/cloudflared/mocks" "github.com/cloudflare/cloudflared/mocks"
) )
const (
emptyCert = ""
)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Helpers // Helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -114,7 +118,7 @@ func TestRun_AllPass(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results. // 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results.
@@ -145,7 +149,7 @@ func TestRun_QUICBlocked(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Pass + 1 API Pass. // 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Pass + 1 API Pass.
@@ -175,7 +179,7 @@ func TestRun_HTTP2Blocked(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Fail + 1 API Pass. // 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Fail + 1 API Pass.
@@ -205,7 +209,7 @@ func TestRun_BothTransportsBlocked(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Fail + 1 API Pass. // 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Fail + 1 API Pass.
@@ -244,7 +248,7 @@ func TestRun_PartialRegionQUICFail(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass + QUIC-region1 Pass + QUIC-region2 Fail + 2 HTTP2 Pass + 1 API Pass. // 2 DNS Pass + QUIC-region1 Pass + QUIC-region2 Fail + 2 HTTP2 Pass + 1 API Pass.
@@ -277,7 +281,7 @@ func TestRun_DNSFail_SkipsTransports(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// DNS failure emits 2 Fail rows (one per default region). // DNS failure emits 2 Fail rows (one per default region).
@@ -314,7 +318,7 @@ func TestRun_ManagementAPIFail(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, errors.New("connection refused")).AnyTimes() Return(nil, errors.New("connection refused")).AnyTimes()
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Pass + 1 API Fail. // 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Pass + 1 API Fail.
@@ -345,7 +349,7 @@ func TestRun_RegionFlagForwardedToDNS(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Region: "us", Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Region: "us", Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// DNS rows carry regional hostnames (indices 0 and 1). // DNS rows carry regional hostnames (indices 0 and 1).
@@ -383,7 +387,7 @@ func TestRun_QUICUsesProbeConnIndex(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
} }
@@ -407,7 +411,7 @@ func TestRun_BothFamiliesProbed(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results, all passing. // 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results, all passing.
@@ -437,7 +441,7 @@ func TestRun_IPv4OnlySkipsV6(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
@@ -464,7 +468,7 @@ func TestRun_IPv6OnlySkipsV4(t *testing.T) {
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil) Return(nopConn{}, nil)
report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only}, report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
+23 -16
View File
@@ -17,6 +17,7 @@ import (
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
edgedial "github.com/cloudflare/cloudflared/edgediscovery" edgedial "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/tlsconfig"
) )
const ( const (
@@ -50,6 +51,7 @@ const (
detailsConnectionFailed = "Connection failed" detailsConnectionFailed = "Connection failed"
detailsTCPPortReachable = "TCP port reachable (TLS not validated)" detailsTCPPortReachable = "TCP port reachable (TLS not validated)"
detailsDNSPrerequisiteFailed = "DNS prerequisite failed" detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
detailsTLSConfigFailed = "TLS configuration failed"
// Region hostname templates. // Region hostname templates.
region1Global = "region1.v2.argotunnel.com" region1Global = "region1.v2.argotunnel.com"
@@ -103,6 +105,25 @@ func (d *NetManagementDialer) DialContext(ctx context.Context, network, addr str
return d.Dialer.DialContext(ctx, network, addr) return d.Dialer.DialContext(ctx, network, addr)
} }
// probeTLSConfig builds a *tls.Config for a pre-check probe using the same
// certificate pool as the production tunnel. The SNI and NextProtos are taken from
// p.ProbeTLSSettings() so that the probe SNI is used instead of the production SNI,
// which avoids noisy logs in origintunneld.
func probeTLSConfig(caCert string, p connection.Protocol) (*tls.Config, error) {
settings := p.ProbeTLSSettings()
if settings == nil {
return nil, fmt.Errorf("no probe TLS settings for protocol %s", p)
}
cfg, err := tlsconfig.CreateTunnelConfig(caCert, settings.ServerName)
if err != nil {
return nil, err
}
if len(settings.NextProtos) > 0 {
cfg.NextProtos = settings.NextProtos
}
return cfg, nil
}
// probeDNS resolves edge addresses for the given region via the supplied // probeDNS resolves edge addresses for the given region via the supplied
// DNSResolver and returns a CheckResult for each region discovered. If // DNSResolver and returns a CheckResult for each region discovered. If
// resolution fails for all regions, every result will carry StatusFail. // resolution fails for all regions, every result will carry StatusFail.
@@ -153,6 +174,7 @@ func probeDNS(
// budget. // budget.
func probeQUIC( func probeQUIC(
ctx context.Context, ctx context.Context,
tlsConfig *tls.Config,
dialer QUICDialer, dialer QUICDialer,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
logger *zerolog.Logger, logger *zerolog.Logger,
@@ -160,14 +182,6 @@ func probeQUIC(
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout) dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
defer cancel() defer cancel()
tlsSettings := connection.QUIC.ProbeTLSSettings()
tlsConfig := &tls.Config{
ServerName: tlsSettings.ServerName,
NextProtos: tlsSettings.NextProtos,
MinVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{tls.CurveP256},
}
// We call dialer.DialQuic with isProbe = true, which bypasses connIndex check. // We call dialer.DialQuic with isProbe = true, which bypasses connIndex check.
// Therefore, whatever we add to connIndex will not be relevant. // Therefore, whatever we add to connIndex will not be relevant.
edgeAddrPort := addr.UDP.AddrPort() edgeAddrPort := addr.UDP.AddrPort()
@@ -213,14 +227,7 @@ func probeQUIC(
// //
// The dial timeout is capped at perProbeDialTimeout so that a single blocked // The dial timeout is capped at perProbeDialTimeout so that a single blocked
// dial cannot exhaust the entire suite budget. // dial cannot exhaust the entire suite budget.
func probeHTTP2(ctx context.Context, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult { func probeHTTP2(ctx context.Context, tlsConfig *tls.Config, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult {
tlsSettings := connection.HTTP2.ProbeTLSSettings()
tlsConfig := &tls.Config{
ServerName: tlsSettings.ServerName,
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{tls.CurveP256},
}
conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil) conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil)
if err != nil { if err != nil {
return CheckResult{ return CheckResult{
+13 -10
View File
@@ -2,6 +2,7 @@ package prechecks
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"net" "net"
"testing" "testing"
@@ -24,12 +25,14 @@ const (
testRegion2US = region2US testRegion2US = region2US
testRegion1Fed = region1Fed testRegion1Fed = region1Fed
testRegion2Fed = region2Fed testRegion2Fed = region2Fed
testRegion1EU = "eu-region1.v2.argotunnel.com"
testRegion2EU = "eu-region2.v2.argotunnel.com"
testEdgePort = 7844 testEdgePort = 7844
) )
// testTLSConfig is a minimal *tls.Config for tests. Mock dialers never
// perform a real TLS handshake, so an empty config is sufficient.
var testTLSConfig = &tls.Config{} //nolint:gosec
// mockQuicConnection is a minimal test double for quic.Connection. // mockQuicConnection is a minimal test double for quic.Connection.
type mockQuicConnection struct { type mockQuicConnection struct {
closeErr error closeErr error
@@ -231,7 +234,7 @@ func TestProbeQUIC_Success(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil) logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger) result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type) assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus) assert.Equal(t, Pass, result.ProbeStatus)
@@ -249,7 +252,7 @@ func TestProbeQUIC_DialError(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil) logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger) result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type) assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Fail, result.ProbeStatus) assert.Equal(t, Fail, result.ProbeStatus)
@@ -269,7 +272,7 @@ func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil) logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger) result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type) assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus) assert.Equal(t, Pass, result.ProbeStatus)
@@ -287,7 +290,7 @@ func TestProbeQUIC_ContextTimeout(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil) logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger) result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, Fail, result.ProbeStatus) assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsHandshakeFailed, result.Details) assert.Equal(t, detailsHandshakeFailed, result.Details)
@@ -305,7 +308,7 @@ func TestProbeHTTP2_Success(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
result := probeHTTP2(context.Background(), dialer, addr) result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
assert.Equal(t, ProbeTypeHTTP2, result.Type) assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Pass, result.ProbeStatus) assert.Equal(t, Pass, result.ProbeStatus)
@@ -322,7 +325,7 @@ func TestProbeHTTP2_DialError(t *testing.T) {
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
result := probeHTTP2(context.Background(), dialer, addr) result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
assert.Equal(t, ProbeTypeHTTP2, result.Type) assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Fail, result.ProbeStatus) assert.Equal(t, Fail, result.ProbeStatus)
@@ -512,7 +515,7 @@ func TestProbeQUIC_IPv6Address(t *testing.T) {
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
logger := zerolog.New(nil) logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger) result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, Pass, result.ProbeStatus) assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details) assert.Equal(t, detailsHandshakeSuccessful, result.Details)
@@ -530,7 +533,7 @@ func TestProbeHTTP2_IPv6Address(t *testing.T) {
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
result := probeHTTP2(context.Background(), dialer, addr) result := probeHTTP2(context.Background(), testTLSConfig, dialer, addr)
assert.Equal(t, Pass, result.ProbeStatus) assert.Equal(t, Pass, result.ProbeStatus)
} }
+6 -6
View File
@@ -11,12 +11,10 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2"
) )
const ( const (
OriginCAPoolFlag = "origin-ca-pool" OriginCAPoolFlag = "origin-ca-pool"
CaCertFlag = "cacert"
) )
// CertReloader can load and reload a TLS certificate from a particular filepath. // CertReloader can load and reload a TLS certificate from a particular filepath.
@@ -65,7 +63,7 @@ func (cr *CertReloader) LoadCert() error {
// Keep the old certificate if there's a problem reading the new one. // Keep the old certificate if there's a problem reading the new one.
if err != nil { if err != nil {
sentry.CaptureException(fmt.Errorf("Error parsing X509 key pair: %v", err)) sentry.CaptureException(fmt.Errorf("error parsing X509 key pair: %v", err))
return err return err
} }
cr.certificate = &cert cr.certificate = &cert
@@ -77,6 +75,7 @@ func LoadOriginCA(originCAPoolFilename string, log *zerolog.Logger) (*x509.CertP
if originCAPoolFilename != "" { if originCAPoolFilename != "" {
var err error var err error
// nolint:gosec
originCustomCAPool, err = os.ReadFile(originCAPoolFilename) originCustomCAPool, err = os.ReadFile(originCAPoolFilename)
if err != nil { if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag)) return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
@@ -116,6 +115,7 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
return certPool, nil return certPool, nil
} }
// nolint: gosec
customOriginCA, err := os.ReadFile(originCAFilename) customOriginCA, err := os.ReadFile(originCAFilename)
if err != nil { if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", originCAFilename)) return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", originCAFilename))
@@ -127,10 +127,10 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
return certPool, nil return certPool, nil
} }
func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) { func CreateTunnelConfig(caCert string, serverName string) (*tls.Config, error) {
var rootCAs []string var rootCAs []string
if c.String(CaCertFlag) != "" { if caCert != "" {
rootCAs = append(rootCAs, c.String(CaCertFlag)) rootCAs = append(rootCAs, caCert)
} }
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName} userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}