TUN-10388 Implement dialers for connectivity checks

This PR implements all the dialers and resolvers needed to make pre-checks happen. So this task focuses on the following:

1. Implement the DNS probe: call DNSResolver.Resolve(region)
2. Implement the QUIC probe: call QUICDialer.DialQuic (handshake only, no stream opened) and record the result.
3. Implement the HTTP/2 probe: call TCPDialer.DialEdge (TCP + TLS handshake only, no frames sent) and record the result.
4. Implement the Management API probe: call ManagementDialer.DialContext to api.cloudflare.com:443 and record the result.
5. Export edgeDiscovery as EdgeDiscovery in edgediscovery/allregions/discovery.go so the pre-check can reuse the production DNS path.

This sets up the main components to implement the checker.
This commit is contained in:
Miguel da Costa Martins Marcelino
2026-04-30 15:15:25 +00:00
parent a0401df621
commit 9978cfd0d5
10 changed files with 1171 additions and 17 deletions
+1
View File
@@ -912,6 +912,7 @@ and virtualized host network stacks from each other`,
}
func configureProxyFlags(shouldHide bool) []cli.Flag {
//nolint: prealloc
flags := []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
+1 -1
View File
@@ -109,7 +109,7 @@ var friendlyDNSErrorLines = []string{
}
// EdgeDiscovery implements HA service discovery lookup.
func edgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) {
func EdgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) {
logger := log.With().Int(management.EventTypeKey, int(management.Cloudflared)).Logger()
logger.Debug().
Int(management.EventTypeKey, int(management.Cloudflared)).
+3 -2
View File
@@ -6,6 +6,7 @@ import (
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func (ea *EdgeAddr) String() string {
@@ -25,8 +26,8 @@ func TestEdgeDiscovery(t *testing.T) {
}
l := zerolog.Nop()
addrLists, err := edgeDiscovery(&l, "")
assert.NoError(t, err)
addrLists, err := EdgeDiscovery(&l, "")
require.NoError(t, err)
actualAddrSet := map[string]bool{}
for _, addrs := range addrLists {
for _, addr := range addrs {
+8 -5
View File
@@ -20,7 +20,7 @@ type Regions struct {
// ResolveEdge resolves the Cloudflare edge, returning all regions discovered.
func ResolveEdge(log *zerolog.Logger, region string, overrideIPVersion ConfigIPVersion) (*Regions, error) {
edgeAddrs, err := edgeDiscovery(log, getRegionalServiceName(region))
edgeAddrs, err := EdgeDiscovery(log, RegionalServiceName(region))
if err != nil {
return nil, err
}
@@ -91,6 +91,7 @@ func (rs *Regions) GetUnusedAddr(excluding *EdgeAddr, connID int) *EdgeAddr {
// evenly across both regions.
if rs.region1.AvailableAddrs() == rs.region2.AvailableAddrs() {
regions := []Region{rs.region1, rs.region2}
//nolint:gosec
firstChoice := rand.Intn(2)
return getAddrs(excluding, connID, &regions[firstChoice], &regions[1-firstChoice])
}
@@ -131,11 +132,13 @@ func (rs *Regions) GiveBack(addr *EdgeAddr, hasConnectivityError bool) bool {
return rs.region2.GiveBack(addr, hasConnectivityError)
}
// Return regionalized service name if `region` isn't empty, otherwise return the global service name for origintunneld
func getRegionalServiceName(region string) string {
// RegionalServiceName returns the SRV service name for the given region.
// When region is empty it returns the global service name ("v2-origintunneld").
// Otherwise, it prepends the region, e.g. "us-v2-origintunneld".
func RegionalServiceName(region string) string {
if region != "" {
return region + "-" + srvService // Example: `us-v2-origintunneld`
return region + "-" + srvService
}
return srvService // Global service is just `v2-origintunneld`
return srvService
}
+4 -6
View File
@@ -237,21 +237,19 @@ func TestNewNoResolveBalancesRegions(t *testing.T) {
}
}
func TestGetRegionalServiceName(t *testing.T) {
func TestRegionalServiceName(t *testing.T) {
// Empty region should just go to origintunneld
globalServiceName := getRegionalServiceName("")
assert.Equal(t, srvService, globalServiceName)
assert.Equal(t, srvService, RegionalServiceName(""))
// Non-empty region should go to the regional origintunneld variant
for _, region := range []string{"us", "pt", "am"} {
regionalServiceName := getRegionalServiceName(region)
assert.Equal(t, region+"-"+srvService, regionalServiceName)
assert.Equal(t, region+"-"+srvService, RegionalServiceName(region))
}
}
func RegionsIsBalanced(t *testing.T, rs *Regions) {
delta := rs.region1.AvailableAddrs() - rs.region2.AvailableAddrs()
assert.True(t, abs(delta) <= 1)
assert.LessOrEqual(t, abs(delta), 1)
}
func abs(x int) int {
+278
View File
@@ -0,0 +1,278 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../prechecks/resolvers.go
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go
//
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
tls "crypto/tls"
net "net"
netip "net/netip"
reflect "reflect"
time "time"
quic "github.com/quic-go/quic-go"
zerolog "github.com/rs/zerolog"
gomock "go.uber.org/mock/gomock"
dialopts "github.com/cloudflare/cloudflared/connection/dialopts"
allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions"
)
// MockDNSResolver is a mock of DNSResolver interface.
type MockDNSResolver struct {
ctrl *gomock.Controller
recorder *MockDNSResolverMockRecorder
isgomock struct{}
}
// MockDNSResolverMockRecorder is the mock recorder for MockDNSResolver.
type MockDNSResolverMockRecorder struct {
mock *MockDNSResolver
}
// NewMockDNSResolver creates a new mock instance.
func NewMockDNSResolver(ctrl *gomock.Controller) *MockDNSResolver {
mock := &MockDNSResolver{ctrl: ctrl}
mock.recorder = &MockDNSResolverMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDNSResolver) EXPECT() *MockDNSResolverMockRecorder {
return m.recorder
}
// Resolve mocks base method.
func (m *MockDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Resolve", region)
ret0, _ := ret[0].([][]*allregions.EdgeAddr)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Resolve indicates an expected call of Resolve.
func (mr *MockDNSResolverMockRecorder) Resolve(region any) *MockDNSResolverResolveCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockDNSResolver)(nil).Resolve), region)
return &MockDNSResolverResolveCall{Call: call}
}
// MockDNSResolverResolveCall wrap *gomock.Call
type MockDNSResolverResolveCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockDNSResolverResolveCall) Return(arg0 [][]*allregions.EdgeAddr, arg1 error) *MockDNSResolverResolveCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockDNSResolverResolveCall) Do(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockDNSResolverResolveCall) DoAndReturn(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// MockTCPDialer is a mock of TCPDialer interface.
type MockTCPDialer struct {
ctrl *gomock.Controller
recorder *MockTCPDialerMockRecorder
isgomock struct{}
}
// MockTCPDialerMockRecorder is the mock recorder for MockTCPDialer.
type MockTCPDialerMockRecorder struct {
mock *MockTCPDialer
}
// NewMockTCPDialer creates a new mock instance.
func NewMockTCPDialer(ctrl *gomock.Controller) *MockTCPDialer {
mock := &MockTCPDialer{ctrl: ctrl}
mock.recorder = &MockTCPDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTCPDialer) EXPECT() *MockTCPDialerMockRecorder {
return m.recorder
}
// DialEdge mocks base method.
func (m *MockTCPDialer) DialEdge(ctx context.Context, timeout time.Duration, tlsConfig *tls.Config, addr *net.TCPAddr, localIP net.IP) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialEdge", ctx, timeout, tlsConfig, addr, localIP)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialEdge indicates an expected call of DialEdge.
func (mr *MockTCPDialerMockRecorder) DialEdge(ctx, timeout, tlsConfig, addr, localIP any) *MockTCPDialerDialEdgeCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialEdge", reflect.TypeOf((*MockTCPDialer)(nil).DialEdge), ctx, timeout, tlsConfig, addr, localIP)
return &MockTCPDialerDialEdgeCall{Call: call}
}
// MockTCPDialerDialEdgeCall wrap *gomock.Call
type MockTCPDialerDialEdgeCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockTCPDialerDialEdgeCall) Return(arg0 net.Conn, arg1 error) *MockTCPDialerDialEdgeCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockTCPDialerDialEdgeCall) Do(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockTCPDialerDialEdgeCall) DoAndReturn(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// MockQUICDialer is a mock of QUICDialer interface.
type MockQUICDialer struct {
ctrl *gomock.Controller
recorder *MockQUICDialerMockRecorder
isgomock struct{}
}
// MockQUICDialerMockRecorder is the mock recorder for MockQUICDialer.
type MockQUICDialerMockRecorder struct {
mock *MockQUICDialer
}
// NewMockQUICDialer creates a new mock instance.
func NewMockQUICDialer(ctrl *gomock.Controller) *MockQUICDialer {
mock := &MockQUICDialer{ctrl: ctrl}
mock.recorder = &MockQUICDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder {
return m.recorder
}
// DialQuic mocks base method.
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
ret0, _ := ret[0].(quic.Connection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialQuic indicates an expected call of DialQuic.
func (mr *MockQUICDialerMockRecorder) DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts any) *MockQUICDialerDialQuicCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialQuic", reflect.TypeOf((*MockQUICDialer)(nil).DialQuic), ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
return &MockQUICDialerDialQuicCall{Call: call}
}
// MockQUICDialerDialQuicCall wrap *gomock.Call
type MockQUICDialerDialQuicCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// MockManagementDialer is a mock of ManagementDialer interface.
type MockManagementDialer struct {
ctrl *gomock.Controller
recorder *MockManagementDialerMockRecorder
isgomock struct{}
}
// MockManagementDialerMockRecorder is the mock recorder for MockManagementDialer.
type MockManagementDialerMockRecorder struct {
mock *MockManagementDialer
}
// NewMockManagementDialer creates a new mock instance.
func NewMockManagementDialer(ctrl *gomock.Controller) *MockManagementDialer {
mock := &MockManagementDialer{ctrl: ctrl}
mock.recorder = &MockManagementDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManagementDialer) EXPECT() *MockManagementDialerMockRecorder {
return m.recorder
}
// DialContext mocks base method.
func (m *MockManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialContext", ctx, network, addr)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialContext indicates an expected call of DialContext.
func (mr *MockManagementDialerMockRecorder) DialContext(ctx, network, addr any) *MockManagementDialerDialContextCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockManagementDialer)(nil).DialContext), ctx, network, addr)
return &MockManagementDialerDialContextCall{Call: call}
}
// MockManagementDialerDialContextCall wrap *gomock.Call
type MockManagementDialerDialContextCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockManagementDialerDialContextCall) Return(arg0 net.Conn, arg1 error) *MockManagementDialerDialContextCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockManagementDialerDialContextCall) Do(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockManagementDialerDialContextCall) DoAndReturn(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
+2
View File
@@ -3,3 +3,5 @@
package mocks
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go"
+338
View File
@@ -0,0 +1,338 @@
package prechecks
import (
"context"
"crypto/tls"
"fmt"
"math"
"net"
"net/netip"
"time"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection/dialopts"
"github.com/cloudflare/cloudflared/connection"
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
)
const (
perProbeDialTimeout = 5 * time.Second
// Action messages for each probe outcome.
actionDNSFail = "Ensure your DNS resolver can resolve '%s'. Run: dig A %s @1.1.1.1. If that fails, contact your network administrator."
actionQUICBlocked = "QUIC traffic failed to connect to port 7844."
actionHTTP2Blocked = "Allow outbound TCP on port 7844."
actionAPIUnreachable = "cloudflared will still run, but automatic software updates are unavailable. " +
"Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates."
// Component names for CheckResult.
componentDNSResolution = "DNS Resolution"
componentUDPConnectivity = "UDP Connectivity"
componentTCPConnectivity = "TCP Connectivity"
componentCloudflareAPI = "Cloudflare API"
// Target identifiers for CheckResult.
targetPortQUIC = "Port 7844 (QUIC)"
targetPortHTTP2 = "Port 7844 (HTTP/2)"
targetAPI = "api.cloudflare.com:443"
// Details messages for CheckResult.
detailsNoAddressesReturned = "No addresses returned"
detailsResolvedSuccessfully = "Resolved successfully"
detailsHandshakeFailed = "Handshake failed"
detailsHandshakeSuccessful = "Handshake successful"
detailsBlockedOrUnreachable = "Blocked or unreachable"
detailsTLSHandshakeSuccessful = "TLS handshake successful"
detailsConnectionFailed = "Connection failed"
detailsTCPPortReachable = "TCP port reachable (TLS not validated)"
detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
// Region hostname templates.
region1Global = "region1.v2.argotunnel.com"
region2Global = "region2.v2.argotunnel.com"
region1US = "us-region1.v2.argotunnel.com"
region2US = "us-region2.v2.argotunnel.com"
region1Fed = "fed-region1.v2.argotunnel.com"
region2Fed = "fed-region2.v2.argotunnel.com"
)
type EdgeDNSResolver struct {
Log *zerolog.Logger
}
func (r *EdgeDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) {
return allregions.EdgeDiscovery(r.Log, allregions.RegionalServiceName(region))
}
type EdgeTCPDialer struct{}
func (d *EdgeTCPDialer) DialEdge(
ctx context.Context,
timeout time.Duration,
tlsConfig *tls.Config,
addr *net.TCPAddr,
localIP net.IP,
) (net.Conn, error) {
return edgedial.DialEdge(ctx, timeout, tlsConfig, addr, localIP)
}
type EdgeQUICDialer struct{}
func (d *EdgeQUICDialer) DialQuic(
ctx context.Context,
quicConfig *quic.Config,
tlsConfig *tls.Config,
addr netip.AddrPort,
localAddr net.IP,
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error) {
return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
}
type NetManagementDialer struct {
Dialer net.Dialer
}
func (d *NetManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, addr)
}
// probeDNS resolves edge addresses for the given region via the supplied
// DNSResolver and returns a CheckResult for each region discovered. If
// resolution fails for all regions, every result will carry StatusFail.
func probeDNS(
resolver DNSResolver,
region string,
) ([][]*allregions.EdgeAddr, []CheckResult) {
addrGroups, err := resolver.Resolve(region)
if err != nil || len(addrGroups) == 0 {
detail := detailsNoAddressesReturned
if err != nil {
detail = err.Error()
}
region1Target, region2Target := regionTargets(region)
return nil, []CheckResult{
{
Type: ProbeTypeDNS,
Component: componentDNSResolution,
Target: region1Target,
ProbeStatus: Fail,
Details: detail,
Action: fmt.Sprintf(actionDNSFail, region1Target, region1Target),
},
{
Type: ProbeTypeDNS,
Component: componentDNSResolution,
Target: region2Target,
ProbeStatus: Fail,
Details: detail,
Action: fmt.Sprintf(actionDNSFail, region2Target, region2Target),
},
}
}
region1Target, region2Target := regionTargets(region)
targets := []string{region1Target, region2Target}
results := make([]CheckResult, 0, len(addrGroups))
for i, group := range addrGroups {
target := fmt.Sprintf("region%d.v2.argotunnel.com", i+1)
if i < len(targets) {
target = targets[i]
}
if len(group) == 0 {
results = append(results, CheckResult{
Type: ProbeTypeDNS,
Component: componentDNSResolution,
Target: target,
ProbeStatus: Fail,
Details: detailsNoAddressesReturned,
Action: fmt.Sprintf(actionDNSFail, target, target),
})
} else {
results = append(results, CheckResult{
Type: ProbeTypeDNS,
Component: componentDNSResolution,
Target: target,
ProbeStatus: Pass,
Details: detailsResolvedSuccessfully,
})
}
}
return addrGroups, results
}
// probeQUIC performs a QUIC handshake to a single edge address and returns a
// CheckResult. The connection is closed immediately after the handshake no
// streams are opened and no RPC frames are sent to avoid triggering the OTD
// registration timeout (TUN-6732). The probe SNI (probe.cftunnel.com) is used
// instead of the production quic.cftunnel.com to prevent OTD log noise.
//
// A per-probe deadline (perProbeDialTimeout) is applied on top of the parent
// context so that a single blocked handshake cannot consume the entire suite
// budget.
func probeQUIC(
ctx context.Context,
dialer QUICDialer,
addr *allregions.EdgeAddr,
logger *zerolog.Logger,
) CheckResult {
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
defer cancel()
tlsSettings := connection.QUIC.ProbeTLSSettings()
tlsConfig := &tls.Config{
ServerName: tlsSettings.ServerName,
NextProtos: tlsSettings.NextProtos,
MinVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{tls.CurveP256},
}
// We call dialer.DialQuic with isProbe = true, which bypasses connIndex check.
// Therefore, whatever we add to connIndex will not be relevant.
edgeAddrPort := addr.UDP.AddrPort()
conn, err := dialer.DialQuic(
dialCtx,
&quic.Config{},
tlsConfig,
edgeAddrPort,
nil,
math.MaxUint8,
logger,
dialopts.DialOpts{SkipPortReuse: true},
)
if err != nil {
return CheckResult{
Type: ProbeTypeQUIC,
Component: componentUDPConnectivity,
Target: targetPortQUIC,
ProbeStatus: Fail,
Details: detailsHandshakeFailed,
Action: actionQUICBlocked,
}
}
if err := conn.CloseWithError(0, "precheck complete"); err != nil {
logger.Debug().Err(err).Msg("Failed to close QUIC connection after successful handshake")
}
return CheckResult{
Type: ProbeTypeQUIC,
Component: componentUDPConnectivity,
Target: targetPortQUIC,
ProbeStatus: Pass,
Details: detailsHandshakeSuccessful,
}
}
// probeHTTP2 performs a TCP + TLS handshake to a single edge address and
// returns a CheckResult. The connection is closed immediately after the
// handshake no HTTP/2 frames are sent to keep the probe minimal. The probe
// SNI (probe.cftunnel.com) is used instead of the production h2.cftunnel.com
// to prevent OTD log noise.
//
// The dial timeout is capped at perProbeDialTimeout so that a single blocked
// dial cannot exhaust the entire suite budget.
func probeHTTP2(ctx context.Context, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult {
tlsSettings := connection.HTTP2.ProbeTLSSettings()
tlsConfig := &tls.Config{
ServerName: tlsSettings.ServerName,
MinVersion: tls.VersionTLS12,
CurvePreferences: []tls.CurveID{tls.CurveP256},
}
conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil)
if err != nil {
return CheckResult{
Type: ProbeTypeHTTP2,
Component: componentTCPConnectivity,
Target: targetPortHTTP2,
ProbeStatus: Fail,
Details: detailsBlockedOrUnreachable,
Action: actionHTTP2Blocked,
}
}
_ = conn.Close()
return CheckResult{
Type: ProbeTypeHTTP2,
Component: componentTCPConnectivity,
Target: targetPortHTTP2,
ProbeStatus: Pass,
Details: detailsTLSHandshakeSuccessful,
}
}
// probeManagementAPI tests TCP connectivity to api.cloudflare.com:443. A
// successful TCP connection (no TLS handshake required) confirms the port is
// reachable. This probe is always a soft failure: the tunnel can run without
// it, but automatic software updates will be unavailable.
func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResult {
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
defer cancel()
conn, err := dialer.DialContext(dialCtx, "tcp", targetAPI)
if err != nil {
return CheckResult{
Type: ProbeTypeManagementAPI,
Component: componentCloudflareAPI,
Target: targetAPI,
ProbeStatus: Fail,
Details: detailsConnectionFailed,
Action: actionAPIUnreachable,
}
}
_ = conn.Close()
return CheckResult{
Type: ProbeTypeManagementAPI,
Component: componentCloudflareAPI,
Target: targetAPI,
ProbeStatus: Pass,
Details: detailsTCPPortReachable,
}
}
func skipResult(probeType ProbeType, component, target string) CheckResult {
return CheckResult{
Type: probeType,
Component: component,
Target: target,
ProbeStatus: Skip,
Details: detailsDNSPrerequisiteFailed,
}
}
// regionTargets returns the human-readable hostnames for region1 and region2
// based on the optional region flag value.
func regionTargets(region string) (string, string) {
switch region {
case "us":
return region1US, region2US
case "fed":
return region1Fed, region2Fed
default:
return region1Global, region2Global
}
}
// addrsByFamily extracts one V4 and one V6 address from a resolved CNAME group
// using allregions.NewRegion so that the IP-version preference logic matches
// production exactly. When cfg.IPVersion restricts to a single family the
// excluded family's pointer is nil.
func addrsByFamily(group []*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) (v4, v6 *allregions.EdgeAddr) {
if ipVersion != allregions.IPv6Only {
v4 = allregions.NewRegion(group, allregions.IPv4Only).GetAnyAddress()
}
if ipVersion != allregions.IPv4Only {
v6 = allregions.NewRegion(group, allregions.IPv6Only).GetAnyAddress()
}
return
}
+536
View File
@@ -0,0 +1,536 @@
package prechecks
import (
"context"
"errors"
"net"
"testing"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/mocks"
)
// Test constants for repeated string values.
const (
testRegion1Global = region1Global
testRegion2Global = region2Global
testRegion1US = region1US
testRegion2US = region2US
testRegion1Fed = region1Fed
testRegion2Fed = region2Fed
testRegion1EU = "eu-region1.v2.argotunnel.com"
testRegion2EU = "eu-region2.v2.argotunnel.com"
testEdgePort = 7844
)
// mockQuicConnection is a minimal test double for quic.Connection.
type mockQuicConnection struct {
closeErr error
}
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) LocalAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) RemoteAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error {
return m.closeErr
}
func (m *mockQuicConnection) Context() context.Context {
return context.Background()
}
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
return quic.ConnectionState{}
}
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
return nil
}
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
return nil, nil
}
func (m *mockQuicConnection) AddPath(*quic.Transport) (*quic.Path, error) {
return nil, nil
}
// Helper to create test edge addresses.
func createTestEdgeAddr(ip string, port int, version allregions.EdgeIPVersion) *allregions.EdgeAddr {
parsedIP := net.ParseIP(ip)
return &allregions.EdgeAddr{
TCP: &net.TCPAddr{IP: parsedIP, Port: port},
UDP: &net.UDPAddr{IP: parsedIP, Port: port},
IPVersion: version,
}
}
// probeDNS tests.
func TestProbeDNS_Success(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr, v6Addr}}, nil)
addrs, results := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 1)
assert.Len(t, addrs, 1)
assert.Equal(t, ProbeTypeDNS, results[0].Type)
assert.Equal(t, testRegion1Global, results[0].Target)
assert.Equal(t, Pass, results[0].ProbeStatus)
assert.Equal(t, detailsResolvedSuccessfully, results[0].Details)
}
func TestProbeDNS_MultipleRegions(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
v4Addr1 := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
v4Addr2 := createTestEdgeAddr("192.0.2.2", testEdgePort, allregions.V4)
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr1}, {v4Addr2}}, nil)
addrs, results := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 2)
assert.Len(t, addrs, 2)
assert.Equal(t, testRegion1Global, results[0].Target)
assert.Equal(t, Pass, results[0].ProbeStatus)
assert.Equal(t, testRegion2Global, results[1].Target)
assert.Equal(t, Pass, results[1].ProbeStatus)
}
func TestProbeDNS_ResolverError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return(nil, errors.New("DNS lookup failed"))
addrs, results := probeDNS(resolver, "")
assert.Nil(t, addrs)
require.Len(t, results, 2)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "DNS lookup failed", results[0].Details)
assert.Contains(t, results[0].Action, testRegion1Global)
assert.Contains(t, results[1].Action, testRegion2Global)
assert.Equal(t, Fail, results[1].ProbeStatus)
}
func TestProbeDNS_EmptyResults(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{}, nil)
addrs, results := probeDNS(resolver, "")
assert.Nil(t, addrs)
require.Len(t, results, 2)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "No addresses returned", results[0].Details)
}
func TestProbeDNS_EmptyGroup(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{}}, nil)
addrs, results := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 1)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "No addresses returned", results[0].Details)
}
func TestProbeDNS_RegionFlag(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("us").Return([][]*allregions.EdgeAddr{{v4Addr}}, nil)
_, results := probeDNS(resolver, "us")
require.Len(t, results, 1)
assert.Equal(t, testRegion1US, results[0].Target)
}
// probeQUIC tests.
func TestProbeQUIC_Success(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := &mockQuicConnection{}
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
}
func TestProbeQUIC_DialError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused"))
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsHandshakeFailed, result.Details)
assert.Equal(t, actionQUICBlocked, result.Action)
}
func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := &mockQuicConnection{closeErr: errors.New("close failed")}
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger)
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
}
func TestProbeQUIC_ContextTimeout(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, context.DeadlineExceeded)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsHandshakeFailed, result.Details)
}
// probeHTTP2 tests.
func TestProbeHTTP2_Success(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockTCPDialer(ctrl)
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
result := probeHTTP2(context.Background(), dialer, addr)
assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsTLSHandshakeSuccessful, result.Details)
}
func TestProbeHTTP2_DialError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockTCPDialer(ctrl)
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused"))
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
result := probeHTTP2(context.Background(), dialer, addr)
assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsBlockedOrUnreachable, result.Details)
assert.Equal(t, actionHTTP2Blocked, result.Action)
}
// probeManagementAPI tests.
func TestProbeManagementAPI_Success(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockManagementDialer(ctrl)
dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(&net.TCPConn{}, nil)
result := probeManagementAPI(context.Background(), dialer)
assert.Equal(t, ProbeTypeManagementAPI, result.Type)
assert.Equal(t, "Cloudflare API", result.Component)
assert.Equal(t, "api.cloudflare.com:443", result.Target)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsTCPPortReachable, result.Details)
}
func TestProbeManagementAPI_DialError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockManagementDialer(ctrl)
dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(nil, errors.New("connection refused"))
result := probeManagementAPI(context.Background(), dialer)
assert.Equal(t, ProbeTypeManagementAPI, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsConnectionFailed, result.Details)
assert.Equal(t, actionAPIUnreachable, result.Action)
}
// skipResult tests.
func TestSkipResult(t *testing.T) {
t.Parallel()
result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)")
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, "UDP Connectivity", result.Component)
assert.Equal(t, "Port 7844 (QUIC)", result.Target)
assert.Equal(t, Skip, result.ProbeStatus)
assert.Equal(t, detailsDNSPrerequisiteFailed, result.Details)
}
// regionTargets tests.
func TestRegionTargets(t *testing.T) {
t.Parallel()
tests := []struct {
name string
region string
wantRegion1 string
wantRegion2 string
description string
}{
{
name: "empty region returns global hostnames",
region: "",
wantRegion1: testRegion1Global,
wantRegion2: testRegion2Global,
},
{
name: "us region returns US hostnames",
region: "us",
wantRegion1: testRegion1US,
wantRegion2: testRegion2US,
},
{
name: "fed region returns fed hostnames",
region: "fed",
wantRegion1: testRegion1Fed,
wantRegion2: testRegion2Fed,
},
{
name: "unknown region defaults to global hostnames",
region: "eu",
wantRegion1: testRegion1Global,
wantRegion2: testRegion2Global,
description: "Unknown regions should default to global hostnames",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotR1, gotR2 := regionTargets(tt.region)
assert.Equal(t, tt.wantRegion1, gotR1)
assert.Equal(t, tt.wantRegion2, gotR2)
})
}
}
// addrsByFamily tests.
func TestAddrsByFamily(t *testing.T) {
t.Parallel()
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
tests := []struct {
name string
group []*allregions.EdgeAddr
ipVersion allregions.ConfigIPVersion
wantV4 bool
wantV6 bool
}{
{
name: "auto returns both v4 and v6",
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
ipVersion: allregions.Auto,
wantV4: true,
wantV6: true,
},
{
name: "ipv4 only returns v4 and nil v6",
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
ipVersion: allregions.IPv4Only,
wantV4: true,
wantV6: false,
},
{
name: "ipv6 only returns nil v4 and v6",
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
ipVersion: allregions.IPv6Only,
wantV4: false,
wantV6: true,
},
{
name: "empty group returns nil for both",
group: []*allregions.EdgeAddr{},
ipVersion: allregions.Auto,
wantV4: false,
wantV6: false,
},
{
name: "only v4 available returns v4 and nil v6",
group: []*allregions.EdgeAddr{v4Addr},
ipVersion: allregions.Auto,
wantV4: true,
wantV6: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotV4, gotV6 := addrsByFamily(tt.group, tt.ipVersion)
if tt.wantV4 {
assert.NotNil(t, gotV4)
} else {
assert.Nil(t, gotV4)
}
if tt.wantV6 {
assert.NotNil(t, gotV6)
} else {
assert.Nil(t, gotV6)
}
})
}
}
// IPv6 address tests for probeQUIC.
func TestProbeQUIC_IPv6Address(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := &mockQuicConnection{}
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
logger := zerolog.New(nil)
result := probeQUIC(context.Background(), dialer, addr, &logger)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
}
// IPv6 address tests for probeHTTP2.
func TestProbeHTTP2_IPv6Address(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
dialer := mocks.NewMockTCPDialer(ctrl)
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil)
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
result := probeHTTP2(context.Background(), dialer, addr)
assert.Equal(t, Pass, result.ProbeStatus)
}
@@ -23,9 +23,6 @@ import (
// system resolver fails, and resolves each discovered hostname via
// net.LookupIP. The returned slice already has each address tagged with
// .IPVersion = V4 or V6.
//
// Note: allregions.EdgeDiscovery must be exported (currently unexported as
// edgeDiscovery) before a production adapter can be wired up.
type DNSResolver interface {
// Resolve performs edge discovery for the given region string (empty for
// global, "us" / "fed" for regional endpoints) and returns the resolved