diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index bffac0b4..c665146c 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -912,6 +912,7 @@ and virtualized host network stacks from each other`, } func configureProxyFlags(shouldHide bool) []cli.Flag { + //nolint: prealloc flags := []cli.Flag{ altsrc.NewStringFlag(&cli.StringFlag{ Name: "url", diff --git a/edgediscovery/allregions/discovery.go b/edgediscovery/allregions/discovery.go index cab06611..52699229 100644 --- a/edgediscovery/allregions/discovery.go +++ b/edgediscovery/allregions/discovery.go @@ -109,7 +109,7 @@ var friendlyDNSErrorLines = []string{ } // EdgeDiscovery implements HA service discovery lookup. -func edgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) { +func EdgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) { logger := log.With().Int(management.EventTypeKey, int(management.Cloudflared)).Logger() logger.Debug(). Int(management.EventTypeKey, int(management.Cloudflared)). diff --git a/edgediscovery/allregions/discovery_test.go b/edgediscovery/allregions/discovery_test.go index 89ed13d5..7b854e6c 100644 --- a/edgediscovery/allregions/discovery_test.go +++ b/edgediscovery/allregions/discovery_test.go @@ -6,6 +6,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func (ea *EdgeAddr) String() string { @@ -25,8 +26,8 @@ func TestEdgeDiscovery(t *testing.T) { } l := zerolog.Nop() - addrLists, err := edgeDiscovery(&l, "") - assert.NoError(t, err) + addrLists, err := EdgeDiscovery(&l, "") + require.NoError(t, err) actualAddrSet := map[string]bool{} for _, addrs := range addrLists { for _, addr := range addrs { diff --git a/edgediscovery/allregions/regions.go b/edgediscovery/allregions/regions.go index b9b7d3ea..d5fc6ca6 100644 --- a/edgediscovery/allregions/regions.go +++ b/edgediscovery/allregions/regions.go @@ -20,7 +20,7 @@ type Regions struct { // ResolveEdge resolves the Cloudflare edge, returning all regions discovered. func ResolveEdge(log *zerolog.Logger, region string, overrideIPVersion ConfigIPVersion) (*Regions, error) { - edgeAddrs, err := edgeDiscovery(log, getRegionalServiceName(region)) + edgeAddrs, err := EdgeDiscovery(log, RegionalServiceName(region)) if err != nil { return nil, err } @@ -91,6 +91,7 @@ func (rs *Regions) GetUnusedAddr(excluding *EdgeAddr, connID int) *EdgeAddr { // evenly across both regions. if rs.region1.AvailableAddrs() == rs.region2.AvailableAddrs() { regions := []Region{rs.region1, rs.region2} + //nolint:gosec firstChoice := rand.Intn(2) return getAddrs(excluding, connID, ®ions[firstChoice], ®ions[1-firstChoice]) } @@ -131,11 +132,13 @@ func (rs *Regions) GiveBack(addr *EdgeAddr, hasConnectivityError bool) bool { return rs.region2.GiveBack(addr, hasConnectivityError) } -// Return regionalized service name if `region` isn't empty, otherwise return the global service name for origintunneld -func getRegionalServiceName(region string) string { +// RegionalServiceName returns the SRV service name for the given region. +// When region is empty it returns the global service name ("v2-origintunneld"). +// Otherwise, it prepends the region, e.g. "us-v2-origintunneld". +func RegionalServiceName(region string) string { if region != "" { - return region + "-" + srvService // Example: `us-v2-origintunneld` + return region + "-" + srvService } - return srvService // Global service is just `v2-origintunneld` + return srvService } diff --git a/edgediscovery/allregions/regions_test.go b/edgediscovery/allregions/regions_test.go index e399c4ee..70769246 100644 --- a/edgediscovery/allregions/regions_test.go +++ b/edgediscovery/allregions/regions_test.go @@ -237,21 +237,19 @@ func TestNewNoResolveBalancesRegions(t *testing.T) { } } -func TestGetRegionalServiceName(t *testing.T) { +func TestRegionalServiceName(t *testing.T) { // Empty region should just go to origintunneld - globalServiceName := getRegionalServiceName("") - assert.Equal(t, srvService, globalServiceName) + assert.Equal(t, srvService, RegionalServiceName("")) // Non-empty region should go to the regional origintunneld variant for _, region := range []string{"us", "pt", "am"} { - regionalServiceName := getRegionalServiceName(region) - assert.Equal(t, region+"-"+srvService, regionalServiceName) + assert.Equal(t, region+"-"+srvService, RegionalServiceName(region)) } } func RegionsIsBalanced(t *testing.T, rs *Regions) { delta := rs.region1.AvailableAddrs() - rs.region2.AvailableAddrs() - assert.True(t, abs(delta) <= 1) + assert.LessOrEqual(t, abs(delta), 1) } func abs(x int) int { diff --git a/mocks/mock_resolvers.go b/mocks/mock_resolvers.go new file mode 100644 index 00000000..96c41cea --- /dev/null +++ b/mocks/mock_resolvers.go @@ -0,0 +1,278 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../prechecks/resolvers.go +// +// Generated by this command: +// +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + tls "crypto/tls" + net "net" + netip "net/netip" + reflect "reflect" + time "time" + + quic "github.com/quic-go/quic-go" + zerolog "github.com/rs/zerolog" + gomock "go.uber.org/mock/gomock" + + dialopts "github.com/cloudflare/cloudflared/connection/dialopts" + allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions" +) + +// MockDNSResolver is a mock of DNSResolver interface. +type MockDNSResolver struct { + ctrl *gomock.Controller + recorder *MockDNSResolverMockRecorder + isgomock struct{} +} + +// MockDNSResolverMockRecorder is the mock recorder for MockDNSResolver. +type MockDNSResolverMockRecorder struct { + mock *MockDNSResolver +} + +// NewMockDNSResolver creates a new mock instance. +func NewMockDNSResolver(ctrl *gomock.Controller) *MockDNSResolver { + mock := &MockDNSResolver{ctrl: ctrl} + mock.recorder = &MockDNSResolverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDNSResolver) EXPECT() *MockDNSResolverMockRecorder { + return m.recorder +} + +// Resolve mocks base method. +func (m *MockDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Resolve", region) + ret0, _ := ret[0].([][]*allregions.EdgeAddr) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Resolve indicates an expected call of Resolve. +func (mr *MockDNSResolverMockRecorder) Resolve(region any) *MockDNSResolverResolveCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockDNSResolver)(nil).Resolve), region) + return &MockDNSResolverResolveCall{Call: call} +} + +// MockDNSResolverResolveCall wrap *gomock.Call +type MockDNSResolverResolveCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDNSResolverResolveCall) Return(arg0 [][]*allregions.EdgeAddr, arg1 error) *MockDNSResolverResolveCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDNSResolverResolveCall) Do(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDNSResolverResolveCall) DoAndReturn(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockTCPDialer is a mock of TCPDialer interface. +type MockTCPDialer struct { + ctrl *gomock.Controller + recorder *MockTCPDialerMockRecorder + isgomock struct{} +} + +// MockTCPDialerMockRecorder is the mock recorder for MockTCPDialer. +type MockTCPDialerMockRecorder struct { + mock *MockTCPDialer +} + +// NewMockTCPDialer creates a new mock instance. +func NewMockTCPDialer(ctrl *gomock.Controller) *MockTCPDialer { + mock := &MockTCPDialer{ctrl: ctrl} + mock.recorder = &MockTCPDialerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTCPDialer) EXPECT() *MockTCPDialerMockRecorder { + return m.recorder +} + +// DialEdge mocks base method. +func (m *MockTCPDialer) DialEdge(ctx context.Context, timeout time.Duration, tlsConfig *tls.Config, addr *net.TCPAddr, localIP net.IP) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialEdge", ctx, timeout, tlsConfig, addr, localIP) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialEdge indicates an expected call of DialEdge. +func (mr *MockTCPDialerMockRecorder) DialEdge(ctx, timeout, tlsConfig, addr, localIP any) *MockTCPDialerDialEdgeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialEdge", reflect.TypeOf((*MockTCPDialer)(nil).DialEdge), ctx, timeout, tlsConfig, addr, localIP) + return &MockTCPDialerDialEdgeCall{Call: call} +} + +// MockTCPDialerDialEdgeCall wrap *gomock.Call +type MockTCPDialerDialEdgeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTCPDialerDialEdgeCall) Return(arg0 net.Conn, arg1 error) *MockTCPDialerDialEdgeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTCPDialerDialEdgeCall) Do(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTCPDialerDialEdgeCall) DoAndReturn(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockQUICDialer is a mock of QUICDialer interface. +type MockQUICDialer struct { + ctrl *gomock.Controller + recorder *MockQUICDialerMockRecorder + isgomock struct{} +} + +// MockQUICDialerMockRecorder is the mock recorder for MockQUICDialer. +type MockQUICDialerMockRecorder struct { + mock *MockQUICDialer +} + +// NewMockQUICDialer creates a new mock instance. +func NewMockQUICDialer(ctrl *gomock.Controller) *MockQUICDialer { + mock := &MockQUICDialer{ctrl: ctrl} + mock.recorder = &MockQUICDialerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder { + return m.recorder +} + +// DialQuic mocks base method. +func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts) + ret0, _ := ret[0].(quic.Connection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialQuic indicates an expected call of DialQuic. +func (mr *MockQUICDialerMockRecorder) DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts any) *MockQUICDialerDialQuicCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialQuic", reflect.TypeOf((*MockQUICDialer)(nil).DialQuic), ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts) + return &MockQUICDialerDialQuicCall{Call: call} +} + +// MockQUICDialerDialQuicCall wrap *gomock.Call +type MockQUICDialerDialQuicCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockManagementDialer is a mock of ManagementDialer interface. +type MockManagementDialer struct { + ctrl *gomock.Controller + recorder *MockManagementDialerMockRecorder + isgomock struct{} +} + +// MockManagementDialerMockRecorder is the mock recorder for MockManagementDialer. +type MockManagementDialerMockRecorder struct { + mock *MockManagementDialer +} + +// NewMockManagementDialer creates a new mock instance. +func NewMockManagementDialer(ctrl *gomock.Controller) *MockManagementDialer { + mock := &MockManagementDialer{ctrl: ctrl} + mock.recorder = &MockManagementDialerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManagementDialer) EXPECT() *MockManagementDialerMockRecorder { + return m.recorder +} + +// DialContext mocks base method. +func (m *MockManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialContext", ctx, network, addr) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialContext indicates an expected call of DialContext. +func (mr *MockManagementDialerMockRecorder) DialContext(ctx, network, addr any) *MockManagementDialerDialContextCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockManagementDialer)(nil).DialContext), ctx, network, addr) + return &MockManagementDialerDialContextCall{Call: call} +} + +// MockManagementDialerDialContextCall wrap *gomock.Call +type MockManagementDialerDialContextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockManagementDialerDialContextCall) Return(arg0 net.Conn, arg1 error) *MockManagementDialerDialContextCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockManagementDialerDialContextCall) Do(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockManagementDialerDialContextCall) DoAndReturn(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/mocks/mockgen.go b/mocks/mockgen.go index bb68ee31..58c4124e 100644 --- a/mocks/mockgen.go +++ b/mocks/mockgen.go @@ -3,3 +3,5 @@ package mocks //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter" + +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go" diff --git a/prechecks/probes.go b/prechecks/probes.go new file mode 100644 index 00000000..31dd307a --- /dev/null +++ b/prechecks/probes.go @@ -0,0 +1,338 @@ +package prechecks + +import ( + "context" + "crypto/tls" + "fmt" + "math" + "net" + "net/netip" + "time" + + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection/dialopts" + + "github.com/cloudflare/cloudflared/connection" + edgedial "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" +) + +const ( + perProbeDialTimeout = 5 * time.Second + + // Action messages for each probe outcome. + actionDNSFail = "Ensure your DNS resolver can resolve '%s'. Run: dig A %s @1.1.1.1. If that fails, contact your network administrator." + actionQUICBlocked = "QUIC traffic failed to connect to port 7844." + actionHTTP2Blocked = "Allow outbound TCP on port 7844." + actionAPIUnreachable = "cloudflared will still run, but automatic software updates are unavailable. " + + "Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates." + + // Component names for CheckResult. + componentDNSResolution = "DNS Resolution" + componentUDPConnectivity = "UDP Connectivity" + componentTCPConnectivity = "TCP Connectivity" + componentCloudflareAPI = "Cloudflare API" + + // Target identifiers for CheckResult. + targetPortQUIC = "Port 7844 (QUIC)" + targetPortHTTP2 = "Port 7844 (HTTP/2)" + targetAPI = "api.cloudflare.com:443" + + // Details messages for CheckResult. + detailsNoAddressesReturned = "No addresses returned" + detailsResolvedSuccessfully = "Resolved successfully" + detailsHandshakeFailed = "Handshake failed" + detailsHandshakeSuccessful = "Handshake successful" + detailsBlockedOrUnreachable = "Blocked or unreachable" + detailsTLSHandshakeSuccessful = "TLS handshake successful" + detailsConnectionFailed = "Connection failed" + detailsTCPPortReachable = "TCP port reachable (TLS not validated)" + detailsDNSPrerequisiteFailed = "DNS prerequisite failed" + + // Region hostname templates. + region1Global = "region1.v2.argotunnel.com" + region2Global = "region2.v2.argotunnel.com" + region1US = "us-region1.v2.argotunnel.com" + region2US = "us-region2.v2.argotunnel.com" + region1Fed = "fed-region1.v2.argotunnel.com" + region2Fed = "fed-region2.v2.argotunnel.com" +) + +type EdgeDNSResolver struct { + Log *zerolog.Logger +} + +func (r *EdgeDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) { + return allregions.EdgeDiscovery(r.Log, allregions.RegionalServiceName(region)) +} + +type EdgeTCPDialer struct{} + +func (d *EdgeTCPDialer) DialEdge( + ctx context.Context, + timeout time.Duration, + tlsConfig *tls.Config, + addr *net.TCPAddr, + localIP net.IP, +) (net.Conn, error) { + return edgedial.DialEdge(ctx, timeout, tlsConfig, addr, localIP) +} + +type EdgeQUICDialer struct{} + +func (d *EdgeQUICDialer) DialQuic( + ctx context.Context, + quicConfig *quic.Config, + tlsConfig *tls.Config, + addr netip.AddrPort, + localAddr net.IP, + connIndex uint8, + logger *zerolog.Logger, + opts dialopts.DialOpts, +) (quic.Connection, error) { + return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts) +} + +type NetManagementDialer struct { + Dialer net.Dialer +} + +func (d *NetManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return d.Dialer.DialContext(ctx, network, addr) +} + +// probeDNS resolves edge addresses for the given region via the supplied +// DNSResolver and returns a CheckResult for each region discovered. If +// resolution fails for all regions, every result will carry StatusFail. +func probeDNS( + resolver DNSResolver, + region string, +) ([][]*allregions.EdgeAddr, []CheckResult) { + addrGroups, err := resolver.Resolve(region) + if err != nil || len(addrGroups) == 0 { + detail := detailsNoAddressesReturned + if err != nil { + detail = err.Error() + } + region1Target, region2Target := regionTargets(region) + return nil, []CheckResult{ + { + Type: ProbeTypeDNS, + Component: componentDNSResolution, + Target: region1Target, + ProbeStatus: Fail, + Details: detail, + Action: fmt.Sprintf(actionDNSFail, region1Target, region1Target), + }, + { + Type: ProbeTypeDNS, + Component: componentDNSResolution, + Target: region2Target, + ProbeStatus: Fail, + Details: detail, + Action: fmt.Sprintf(actionDNSFail, region2Target, region2Target), + }, + } + } + + region1Target, region2Target := regionTargets(region) + targets := []string{region1Target, region2Target} + + results := make([]CheckResult, 0, len(addrGroups)) + for i, group := range addrGroups { + target := fmt.Sprintf("region%d.v2.argotunnel.com", i+1) + if i < len(targets) { + target = targets[i] + } + if len(group) == 0 { + results = append(results, CheckResult{ + Type: ProbeTypeDNS, + Component: componentDNSResolution, + Target: target, + ProbeStatus: Fail, + Details: detailsNoAddressesReturned, + Action: fmt.Sprintf(actionDNSFail, target, target), + }) + } else { + results = append(results, CheckResult{ + Type: ProbeTypeDNS, + Component: componentDNSResolution, + Target: target, + ProbeStatus: Pass, + Details: detailsResolvedSuccessfully, + }) + } + } + + return addrGroups, results +} + +// probeQUIC performs a QUIC handshake to a single edge address and returns a +// CheckResult. The connection is closed immediately after the handshake – no +// streams are opened and no RPC frames are sent – to avoid triggering the OTD +// registration timeout (TUN-6732). The probe SNI (probe.cftunnel.com) is used +// instead of the production quic.cftunnel.com to prevent OTD log noise. +// +// A per-probe deadline (perProbeDialTimeout) is applied on top of the parent +// context so that a single blocked handshake cannot consume the entire suite +// budget. +func probeQUIC( + ctx context.Context, + dialer QUICDialer, + addr *allregions.EdgeAddr, + logger *zerolog.Logger, +) CheckResult { + dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout) + defer cancel() + + tlsSettings := connection.QUIC.ProbeTLSSettings() + tlsConfig := &tls.Config{ + ServerName: tlsSettings.ServerName, + NextProtos: tlsSettings.NextProtos, + MinVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{tls.CurveP256}, + } + + // We call dialer.DialQuic with isProbe = true, which bypasses connIndex check. + // Therefore, whatever we add to connIndex will not be relevant. + edgeAddrPort := addr.UDP.AddrPort() + conn, err := dialer.DialQuic( + dialCtx, + &quic.Config{}, + tlsConfig, + edgeAddrPort, + nil, + math.MaxUint8, + logger, + dialopts.DialOpts{SkipPortReuse: true}, + ) + if err != nil { + return CheckResult{ + Type: ProbeTypeQUIC, + Component: componentUDPConnectivity, + Target: targetPortQUIC, + ProbeStatus: Fail, + Details: detailsHandshakeFailed, + Action: actionQUICBlocked, + } + } + + if err := conn.CloseWithError(0, "precheck complete"); err != nil { + logger.Debug().Err(err).Msg("Failed to close QUIC connection after successful handshake") + } + + return CheckResult{ + Type: ProbeTypeQUIC, + Component: componentUDPConnectivity, + Target: targetPortQUIC, + ProbeStatus: Pass, + Details: detailsHandshakeSuccessful, + } +} + +// probeHTTP2 performs a TCP + TLS handshake to a single edge address and +// returns a CheckResult. The connection is closed immediately after the +// handshake – no HTTP/2 frames are sent – to keep the probe minimal. The probe +// SNI (probe.cftunnel.com) is used instead of the production h2.cftunnel.com +// to prevent OTD log noise. +// +// The dial timeout is capped at perProbeDialTimeout so that a single blocked +// dial cannot exhaust the entire suite budget. +func probeHTTP2(ctx context.Context, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult { + tlsSettings := connection.HTTP2.ProbeTLSSettings() + tlsConfig := &tls.Config{ + ServerName: tlsSettings.ServerName, + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{tls.CurveP256}, + } + + conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil) + if err != nil { + return CheckResult{ + Type: ProbeTypeHTTP2, + Component: componentTCPConnectivity, + Target: targetPortHTTP2, + ProbeStatus: Fail, + Details: detailsBlockedOrUnreachable, + Action: actionHTTP2Blocked, + } + } + _ = conn.Close() + + return CheckResult{ + Type: ProbeTypeHTTP2, + Component: componentTCPConnectivity, + Target: targetPortHTTP2, + ProbeStatus: Pass, + Details: detailsTLSHandshakeSuccessful, + } +} + +// probeManagementAPI tests TCP connectivity to api.cloudflare.com:443. A +// successful TCP connection (no TLS handshake required) confirms the port is +// reachable. This probe is always a soft failure: the tunnel can run without +// it, but automatic software updates will be unavailable. +func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResult { + dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout) + defer cancel() + + conn, err := dialer.DialContext(dialCtx, "tcp", targetAPI) + if err != nil { + return CheckResult{ + Type: ProbeTypeManagementAPI, + Component: componentCloudflareAPI, + Target: targetAPI, + ProbeStatus: Fail, + Details: detailsConnectionFailed, + Action: actionAPIUnreachable, + } + } + _ = conn.Close() + + return CheckResult{ + Type: ProbeTypeManagementAPI, + Component: componentCloudflareAPI, + Target: targetAPI, + ProbeStatus: Pass, + Details: detailsTCPPortReachable, + } +} + +func skipResult(probeType ProbeType, component, target string) CheckResult { + return CheckResult{ + Type: probeType, + Component: component, + Target: target, + ProbeStatus: Skip, + Details: detailsDNSPrerequisiteFailed, + } +} + +// regionTargets returns the human-readable hostnames for region1 and region2 +// based on the optional region flag value. +func regionTargets(region string) (string, string) { + switch region { + case "us": + return region1US, region2US + case "fed": + return region1Fed, region2Fed + default: + return region1Global, region2Global + } +} + +// addrsByFamily extracts one V4 and one V6 address from a resolved CNAME group +// using allregions.NewRegion so that the IP-version preference logic matches +// production exactly. When cfg.IPVersion restricts to a single family the +// excluded family's pointer is nil. +func addrsByFamily(group []*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) (v4, v6 *allregions.EdgeAddr) { + if ipVersion != allregions.IPv6Only { + v4 = allregions.NewRegion(group, allregions.IPv4Only).GetAnyAddress() + } + if ipVersion != allregions.IPv4Only { + v6 = allregions.NewRegion(group, allregions.IPv6Only).GetAnyAddress() + } + return +} diff --git a/prechecks/probes_test.go b/prechecks/probes_test.go new file mode 100644 index 00000000..d8ce7922 --- /dev/null +++ b/prechecks/probes_test.go @@ -0,0 +1,536 @@ +package prechecks + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/cloudflare/cloudflared/edgediscovery/allregions" + "github.com/cloudflare/cloudflared/mocks" +) + +// Test constants for repeated string values. +const ( + testRegion1Global = region1Global + testRegion2Global = region2Global + testRegion1US = region1US + testRegion2US = region2US + testRegion1Fed = region1Fed + testRegion2Fed = region2Fed + testRegion1EU = "eu-region1.v2.argotunnel.com" + testRegion2EU = "eu-region2.v2.argotunnel.com" + + testEdgePort = 7844 +) + +// mockQuicConnection is a minimal test double for quic.Connection. +type mockQuicConnection struct { + closeErr error +} + +func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) { + return nil, nil +} + +func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) { + return nil, nil +} + +func (m *mockQuicConnection) OpenStream() (quic.Stream, error) { + return nil, nil +} + +func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) { + return nil, nil +} + +func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) { + return nil, nil +} + +func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) { + return nil, nil +} + +func (m *mockQuicConnection) LocalAddr() net.Addr { + return nil +} + +func (m *mockQuicConnection) RemoteAddr() net.Addr { + return nil +} + +func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error { + return m.closeErr +} + +func (m *mockQuicConnection) Context() context.Context { + return context.Background() +} + +func (m *mockQuicConnection) ConnectionState() quic.ConnectionState { + return quic.ConnectionState{} +} + +func (m *mockQuicConnection) SendDatagram(_ []byte) error { + return nil +} + +func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) { + return nil, nil +} + +func (m *mockQuicConnection) AddPath(*quic.Transport) (*quic.Path, error) { + return nil, nil +} + +// Helper to create test edge addresses. +func createTestEdgeAddr(ip string, port int, version allregions.EdgeIPVersion) *allregions.EdgeAddr { + parsedIP := net.ParseIP(ip) + return &allregions.EdgeAddr{ + TCP: &net.TCPAddr{IP: parsedIP, Port: port}, + UDP: &net.UDPAddr{IP: parsedIP, Port: port}, + IPVersion: version, + } +} + +// probeDNS tests. + +func TestProbeDNS_Success(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) + + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr, v6Addr}}, nil) + + addrs, results := probeDNS(resolver, "") + + require.NotNil(t, addrs) + require.Len(t, results, 1) + assert.Len(t, addrs, 1) + assert.Equal(t, ProbeTypeDNS, results[0].Type) + assert.Equal(t, testRegion1Global, results[0].Target) + assert.Equal(t, Pass, results[0].ProbeStatus) + assert.Equal(t, detailsResolvedSuccessfully, results[0].Details) +} + +func TestProbeDNS_MultipleRegions(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + v4Addr1 := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + v4Addr2 := createTestEdgeAddr("192.0.2.2", testEdgePort, allregions.V4) + + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr1}, {v4Addr2}}, nil) + + addrs, results := probeDNS(resolver, "") + + require.NotNil(t, addrs) + require.Len(t, results, 2) + assert.Len(t, addrs, 2) + + assert.Equal(t, testRegion1Global, results[0].Target) + assert.Equal(t, Pass, results[0].ProbeStatus) + + assert.Equal(t, testRegion2Global, results[1].Target) + assert.Equal(t, Pass, results[1].ProbeStatus) +} + +func TestProbeDNS_ResolverError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("").Return(nil, errors.New("DNS lookup failed")) + + addrs, results := probeDNS(resolver, "") + + assert.Nil(t, addrs) + require.Len(t, results, 2) + + assert.Equal(t, Fail, results[0].ProbeStatus) + assert.Equal(t, "DNS lookup failed", results[0].Details) + assert.Contains(t, results[0].Action, testRegion1Global) + assert.Contains(t, results[1].Action, testRegion2Global) + + assert.Equal(t, Fail, results[1].ProbeStatus) +} + +func TestProbeDNS_EmptyResults(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{}, nil) + + addrs, results := probeDNS(resolver, "") + + assert.Nil(t, addrs) + require.Len(t, results, 2) + assert.Equal(t, Fail, results[0].ProbeStatus) + assert.Equal(t, "No addresses returned", results[0].Details) +} + +func TestProbeDNS_EmptyGroup(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{}}, nil) + + addrs, results := probeDNS(resolver, "") + + require.NotNil(t, addrs) + require.Len(t, results, 1) + assert.Equal(t, Fail, results[0].ProbeStatus) + assert.Equal(t, "No addresses returned", results[0].Details) +} + +func TestProbeDNS_RegionFlag(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + resolver := mocks.NewMockDNSResolver(ctrl) + resolver.EXPECT().Resolve("us").Return([][]*allregions.EdgeAddr{{v4Addr}}, nil) + + _, results := probeDNS(resolver, "us") + + require.Len(t, results, 1) + assert.Equal(t, testRegion1US, results[0].Target) +} + +// probeQUIC tests. + +func TestProbeQUIC_Success(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := &mockQuicConnection{} + dialer := mocks.NewMockQUICDialer(ctrl) + dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + logger := zerolog.New(nil) + + result := probeQUIC(context.Background(), dialer, addr, &logger) + + assert.Equal(t, ProbeTypeQUIC, result.Type) + assert.Equal(t, Pass, result.ProbeStatus) + assert.Equal(t, detailsHandshakeSuccessful, result.Details) +} + +func TestProbeQUIC_DialError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockQUICDialer(ctrl) + dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused")) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + logger := zerolog.New(nil) + + result := probeQUIC(context.Background(), dialer, addr, &logger) + + assert.Equal(t, ProbeTypeQUIC, result.Type) + assert.Equal(t, Fail, result.ProbeStatus) + assert.Equal(t, detailsHandshakeFailed, result.Details) + assert.Equal(t, actionQUICBlocked, result.Action) +} + +func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := &mockQuicConnection{closeErr: errors.New("close failed")} + dialer := mocks.NewMockQUICDialer(ctrl) + dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + logger := zerolog.New(nil) + + result := probeQUIC(context.Background(), dialer, addr, &logger) + + assert.Equal(t, ProbeTypeQUIC, result.Type) + assert.Equal(t, Pass, result.ProbeStatus) + assert.Equal(t, detailsHandshakeSuccessful, result.Details) +} + +func TestProbeQUIC_ContextTimeout(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockQUICDialer(ctrl) + dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, context.DeadlineExceeded) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + logger := zerolog.New(nil) + + result := probeQUIC(context.Background(), dialer, addr, &logger) + + assert.Equal(t, Fail, result.ProbeStatus) + assert.Equal(t, detailsHandshakeFailed, result.Details) +} + +// probeHTTP2 tests. + +func TestProbeHTTP2_Success(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockTCPDialer(ctrl) + dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + + result := probeHTTP2(context.Background(), dialer, addr) + + assert.Equal(t, ProbeTypeHTTP2, result.Type) + assert.Equal(t, Pass, result.ProbeStatus) + assert.Equal(t, detailsTLSHandshakeSuccessful, result.Details) +} + +func TestProbeHTTP2_DialError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockTCPDialer(ctrl) + dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused")) + + addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + + result := probeHTTP2(context.Background(), dialer, addr) + + assert.Equal(t, ProbeTypeHTTP2, result.Type) + assert.Equal(t, Fail, result.ProbeStatus) + assert.Equal(t, detailsBlockedOrUnreachable, result.Details) + assert.Equal(t, actionHTTP2Blocked, result.Action) +} + +// probeManagementAPI tests. + +func TestProbeManagementAPI_Success(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockManagementDialer(ctrl) + dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(&net.TCPConn{}, nil) + + result := probeManagementAPI(context.Background(), dialer) + + assert.Equal(t, ProbeTypeManagementAPI, result.Type) + assert.Equal(t, "Cloudflare API", result.Component) + assert.Equal(t, "api.cloudflare.com:443", result.Target) + assert.Equal(t, Pass, result.ProbeStatus) + assert.Equal(t, detailsTCPPortReachable, result.Details) +} + +func TestProbeManagementAPI_DialError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockManagementDialer(ctrl) + dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(nil, errors.New("connection refused")) + + result := probeManagementAPI(context.Background(), dialer) + + assert.Equal(t, ProbeTypeManagementAPI, result.Type) + assert.Equal(t, Fail, result.ProbeStatus) + assert.Equal(t, detailsConnectionFailed, result.Details) + assert.Equal(t, actionAPIUnreachable, result.Action) +} + +// skipResult tests. + +func TestSkipResult(t *testing.T) { + t.Parallel() + + result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)") + + assert.Equal(t, ProbeTypeQUIC, result.Type) + assert.Equal(t, "UDP Connectivity", result.Component) + assert.Equal(t, "Port 7844 (QUIC)", result.Target) + assert.Equal(t, Skip, result.ProbeStatus) + assert.Equal(t, detailsDNSPrerequisiteFailed, result.Details) +} + +// regionTargets tests. + +func TestRegionTargets(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + region string + wantRegion1 string + wantRegion2 string + description string + }{ + { + name: "empty region returns global hostnames", + region: "", + wantRegion1: testRegion1Global, + wantRegion2: testRegion2Global, + }, + { + name: "us region returns US hostnames", + region: "us", + wantRegion1: testRegion1US, + wantRegion2: testRegion2US, + }, + { + name: "fed region returns fed hostnames", + region: "fed", + wantRegion1: testRegion1Fed, + wantRegion2: testRegion2Fed, + }, + { + name: "unknown region defaults to global hostnames", + region: "eu", + wantRegion1: testRegion1Global, + wantRegion2: testRegion2Global, + description: "Unknown regions should default to global hostnames", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotR1, gotR2 := regionTargets(tt.region) + assert.Equal(t, tt.wantRegion1, gotR1) + assert.Equal(t, tt.wantRegion2, gotR2) + }) + } +} + +// addrsByFamily tests. + +func TestAddrsByFamily(t *testing.T) { + t.Parallel() + + v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4) + v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) + + tests := []struct { + name string + group []*allregions.EdgeAddr + ipVersion allregions.ConfigIPVersion + wantV4 bool + wantV6 bool + }{ + { + name: "auto returns both v4 and v6", + group: []*allregions.EdgeAddr{v4Addr, v6Addr}, + ipVersion: allregions.Auto, + wantV4: true, + wantV6: true, + }, + { + name: "ipv4 only returns v4 and nil v6", + group: []*allregions.EdgeAddr{v4Addr, v6Addr}, + ipVersion: allregions.IPv4Only, + wantV4: true, + wantV6: false, + }, + { + name: "ipv6 only returns nil v4 and v6", + group: []*allregions.EdgeAddr{v4Addr, v6Addr}, + ipVersion: allregions.IPv6Only, + wantV4: false, + wantV6: true, + }, + { + name: "empty group returns nil for both", + group: []*allregions.EdgeAddr{}, + ipVersion: allregions.Auto, + wantV4: false, + wantV6: false, + }, + { + name: "only v4 available returns v4 and nil v6", + group: []*allregions.EdgeAddr{v4Addr}, + ipVersion: allregions.Auto, + wantV4: true, + wantV6: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotV4, gotV6 := addrsByFamily(tt.group, tt.ipVersion) + if tt.wantV4 { + assert.NotNil(t, gotV4) + } else { + assert.Nil(t, gotV4) + } + if tt.wantV6 { + assert.NotNil(t, gotV6) + } else { + assert.Nil(t, gotV6) + } + }) + } +} + +// IPv6 address tests for probeQUIC. + +func TestProbeQUIC_IPv6Address(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := &mockQuicConnection{} + dialer := mocks.NewMockQUICDialer(ctrl) + dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil) + + addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) + logger := zerolog.New(nil) + + result := probeQUIC(context.Background(), dialer, addr, &logger) + + assert.Equal(t, Pass, result.ProbeStatus) + assert.Equal(t, detailsHandshakeSuccessful, result.Details) +} + +// IPv6 address tests for probeHTTP2. + +func TestProbeHTTP2_IPv6Address(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dialer := mocks.NewMockTCPDialer(ctrl) + dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil) + + addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6) + + result := probeHTTP2(context.Background(), dialer, addr) + + assert.Equal(t, Pass, result.ProbeStatus) +} diff --git a/prechecks/interfaces.go b/prechecks/resolvers.go similarity index 95% rename from prechecks/interfaces.go rename to prechecks/resolvers.go index 28944881..0262230e 100644 --- a/prechecks/interfaces.go +++ b/prechecks/resolvers.go @@ -23,9 +23,6 @@ import ( // system resolver fails, and resolves each discovered hostname via // net.LookupIP. The returned slice already has each address tagged with // .IPVersion = V4 or V6. -// -// Note: allregions.EdgeDiscovery must be exported (currently unexported as -// edgeDiscovery) before a production adapter can be wired up. type DNSResolver interface { // Resolve performs edge discovery for the given region string (empty for // global, "us" / "fed" for regional endpoints) and returns the resolved