mirror of
https://github.com/cloudflare/cloudflared.git
synced 2026-06-23 04:10:20 +00:00
TUN-10388 Implement dialers for connectivity checks
This PR implements all the dialers and resolvers needed to make pre-checks happen. So this task focuses on the following: 1. Implement the DNS probe: call DNSResolver.Resolve(region) 2. Implement the QUIC probe: call QUICDialer.DialQuic (handshake only, no stream opened) and record the result. 3. Implement the HTTP/2 probe: call TCPDialer.DialEdge (TCP + TLS handshake only, no frames sent) and record the result. 4. Implement the Management API probe: call ManagementDialer.DialContext to api.cloudflare.com:443 and record the result. 5. Export edgeDiscovery as EdgeDiscovery in edgediscovery/allregions/discovery.go so the pre-check can reuse the production DNS path. This sets up the main components to implement the checker.
This commit is contained in:
@@ -912,6 +912,7 @@ and virtualized host network stacks from each other`,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func configureProxyFlags(shouldHide bool) []cli.Flag {
|
func configureProxyFlags(shouldHide bool) []cli.Flag {
|
||||||
|
//nolint: prealloc
|
||||||
flags := []cli.Flag{
|
flags := []cli.Flag{
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "url",
|
Name: "url",
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ var friendlyDNSErrorLines = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EdgeDiscovery implements HA service discovery lookup.
|
// EdgeDiscovery implements HA service discovery lookup.
|
||||||
func edgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) {
|
func EdgeDiscovery(log *zerolog.Logger, srvService string) ([][]*EdgeAddr, error) {
|
||||||
logger := log.With().Int(management.EventTypeKey, int(management.Cloudflared)).Logger()
|
logger := log.With().Int(management.EventTypeKey, int(management.Cloudflared)).Logger()
|
||||||
logger.Debug().
|
logger.Debug().
|
||||||
Int(management.EventTypeKey, int(management.Cloudflared)).
|
Int(management.EventTypeKey, int(management.Cloudflared)).
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ea *EdgeAddr) String() string {
|
func (ea *EdgeAddr) String() string {
|
||||||
@@ -25,8 +26,8 @@ func TestEdgeDiscovery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
l := zerolog.Nop()
|
l := zerolog.Nop()
|
||||||
addrLists, err := edgeDiscovery(&l, "")
|
addrLists, err := EdgeDiscovery(&l, "")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
actualAddrSet := map[string]bool{}
|
actualAddrSet := map[string]bool{}
|
||||||
for _, addrs := range addrLists {
|
for _, addrs := range addrLists {
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type Regions struct {
|
|||||||
|
|
||||||
// ResolveEdge resolves the Cloudflare edge, returning all regions discovered.
|
// ResolveEdge resolves the Cloudflare edge, returning all regions discovered.
|
||||||
func ResolveEdge(log *zerolog.Logger, region string, overrideIPVersion ConfigIPVersion) (*Regions, error) {
|
func ResolveEdge(log *zerolog.Logger, region string, overrideIPVersion ConfigIPVersion) (*Regions, error) {
|
||||||
edgeAddrs, err := edgeDiscovery(log, getRegionalServiceName(region))
|
edgeAddrs, err := EdgeDiscovery(log, RegionalServiceName(region))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -91,6 +91,7 @@ func (rs *Regions) GetUnusedAddr(excluding *EdgeAddr, connID int) *EdgeAddr {
|
|||||||
// evenly across both regions.
|
// evenly across both regions.
|
||||||
if rs.region1.AvailableAddrs() == rs.region2.AvailableAddrs() {
|
if rs.region1.AvailableAddrs() == rs.region2.AvailableAddrs() {
|
||||||
regions := []Region{rs.region1, rs.region2}
|
regions := []Region{rs.region1, rs.region2}
|
||||||
|
//nolint:gosec
|
||||||
firstChoice := rand.Intn(2)
|
firstChoice := rand.Intn(2)
|
||||||
return getAddrs(excluding, connID, ®ions[firstChoice], ®ions[1-firstChoice])
|
return getAddrs(excluding, connID, ®ions[firstChoice], ®ions[1-firstChoice])
|
||||||
}
|
}
|
||||||
@@ -131,11 +132,13 @@ func (rs *Regions) GiveBack(addr *EdgeAddr, hasConnectivityError bool) bool {
|
|||||||
return rs.region2.GiveBack(addr, hasConnectivityError)
|
return rs.region2.GiveBack(addr, hasConnectivityError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return regionalized service name if `region` isn't empty, otherwise return the global service name for origintunneld
|
// RegionalServiceName returns the SRV service name for the given region.
|
||||||
func getRegionalServiceName(region string) string {
|
// When region is empty it returns the global service name ("v2-origintunneld").
|
||||||
|
// Otherwise, it prepends the region, e.g. "us-v2-origintunneld".
|
||||||
|
func RegionalServiceName(region string) string {
|
||||||
if region != "" {
|
if region != "" {
|
||||||
return region + "-" + srvService // Example: `us-v2-origintunneld`
|
return region + "-" + srvService
|
||||||
}
|
}
|
||||||
|
|
||||||
return srvService // Global service is just `v2-origintunneld`
|
return srvService
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -237,21 +237,19 @@ func TestNewNoResolveBalancesRegions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRegionalServiceName(t *testing.T) {
|
func TestRegionalServiceName(t *testing.T) {
|
||||||
// Empty region should just go to origintunneld
|
// Empty region should just go to origintunneld
|
||||||
globalServiceName := getRegionalServiceName("")
|
assert.Equal(t, srvService, RegionalServiceName(""))
|
||||||
assert.Equal(t, srvService, globalServiceName)
|
|
||||||
|
|
||||||
// Non-empty region should go to the regional origintunneld variant
|
// Non-empty region should go to the regional origintunneld variant
|
||||||
for _, region := range []string{"us", "pt", "am"} {
|
for _, region := range []string{"us", "pt", "am"} {
|
||||||
regionalServiceName := getRegionalServiceName(region)
|
assert.Equal(t, region+"-"+srvService, RegionalServiceName(region))
|
||||||
assert.Equal(t, region+"-"+srvService, regionalServiceName)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegionsIsBalanced(t *testing.T, rs *Regions) {
|
func RegionsIsBalanced(t *testing.T, rs *Regions) {
|
||||||
delta := rs.region1.AvailableAddrs() - rs.region2.AvailableAddrs()
|
delta := rs.region1.AvailableAddrs() - rs.region2.AvailableAddrs()
|
||||||
assert.True(t, abs(delta) <= 1)
|
assert.LessOrEqual(t, abs(delta), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func abs(x int) int {
|
func abs(x int) int {
|
||||||
|
|||||||
@@ -0,0 +1,278 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: ../prechecks/resolvers.go
|
||||||
|
//
|
||||||
|
// Generated by this command:
|
||||||
|
//
|
||||||
|
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go
|
||||||
|
//
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
tls "crypto/tls"
|
||||||
|
net "net"
|
||||||
|
netip "net/netip"
|
||||||
|
reflect "reflect"
|
||||||
|
time "time"
|
||||||
|
|
||||||
|
quic "github.com/quic-go/quic-go"
|
||||||
|
zerolog "github.com/rs/zerolog"
|
||||||
|
gomock "go.uber.org/mock/gomock"
|
||||||
|
|
||||||
|
dialopts "github.com/cloudflare/cloudflared/connection/dialopts"
|
||||||
|
allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockDNSResolver is a mock of DNSResolver interface.
|
||||||
|
type MockDNSResolver struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockDNSResolverMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDNSResolverMockRecorder is the mock recorder for MockDNSResolver.
|
||||||
|
type MockDNSResolverMockRecorder struct {
|
||||||
|
mock *MockDNSResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockDNSResolver creates a new mock instance.
|
||||||
|
func NewMockDNSResolver(ctrl *gomock.Controller) *MockDNSResolver {
|
||||||
|
mock := &MockDNSResolver{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockDNSResolverMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockDNSResolver) EXPECT() *MockDNSResolverMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve mocks base method.
|
||||||
|
func (m *MockDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Resolve", region)
|
||||||
|
ret0, _ := ret[0].([][]*allregions.EdgeAddr)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve indicates an expected call of Resolve.
|
||||||
|
func (mr *MockDNSResolverMockRecorder) Resolve(region any) *MockDNSResolverResolveCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resolve", reflect.TypeOf((*MockDNSResolver)(nil).Resolve), region)
|
||||||
|
return &MockDNSResolverResolveCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockDNSResolverResolveCall wrap *gomock.Call
|
||||||
|
type MockDNSResolverResolveCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockDNSResolverResolveCall) Return(arg0 [][]*allregions.EdgeAddr, arg1 error) *MockDNSResolverResolveCall {
|
||||||
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockDNSResolverResolveCall) Do(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockDNSResolverResolveCall) DoAndReturn(f func(string) ([][]*allregions.EdgeAddr, error)) *MockDNSResolverResolveCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTCPDialer is a mock of TCPDialer interface.
|
||||||
|
type MockTCPDialer struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockTCPDialerMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTCPDialerMockRecorder is the mock recorder for MockTCPDialer.
|
||||||
|
type MockTCPDialerMockRecorder struct {
|
||||||
|
mock *MockTCPDialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockTCPDialer creates a new mock instance.
|
||||||
|
func NewMockTCPDialer(ctrl *gomock.Controller) *MockTCPDialer {
|
||||||
|
mock := &MockTCPDialer{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockTCPDialerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockTCPDialer) EXPECT() *MockTCPDialerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialEdge mocks base method.
|
||||||
|
func (m *MockTCPDialer) DialEdge(ctx context.Context, timeout time.Duration, tlsConfig *tls.Config, addr *net.TCPAddr, localIP net.IP) (net.Conn, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DialEdge", ctx, timeout, tlsConfig, addr, localIP)
|
||||||
|
ret0, _ := ret[0].(net.Conn)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialEdge indicates an expected call of DialEdge.
|
||||||
|
func (mr *MockTCPDialerMockRecorder) DialEdge(ctx, timeout, tlsConfig, addr, localIP any) *MockTCPDialerDialEdgeCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialEdge", reflect.TypeOf((*MockTCPDialer)(nil).DialEdge), ctx, timeout, tlsConfig, addr, localIP)
|
||||||
|
return &MockTCPDialerDialEdgeCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTCPDialerDialEdgeCall wrap *gomock.Call
|
||||||
|
type MockTCPDialerDialEdgeCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockTCPDialerDialEdgeCall) Return(arg0 net.Conn, arg1 error) *MockTCPDialerDialEdgeCall {
|
||||||
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockTCPDialerDialEdgeCall) Do(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockTCPDialerDialEdgeCall) DoAndReturn(f func(context.Context, time.Duration, *tls.Config, *net.TCPAddr, net.IP) (net.Conn, error)) *MockTCPDialerDialEdgeCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockQUICDialer is a mock of QUICDialer interface.
|
||||||
|
type MockQUICDialer struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockQUICDialerMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockQUICDialerMockRecorder is the mock recorder for MockQUICDialer.
|
||||||
|
type MockQUICDialerMockRecorder struct {
|
||||||
|
mock *MockQUICDialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockQUICDialer creates a new mock instance.
|
||||||
|
func NewMockQUICDialer(ctrl *gomock.Controller) *MockQUICDialer {
|
||||||
|
mock := &MockQUICDialer{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockQUICDialerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialQuic mocks base method.
|
||||||
|
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
|
||||||
|
ret0, _ := ret[0].(quic.Connection)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialQuic indicates an expected call of DialQuic.
|
||||||
|
func (mr *MockQUICDialerMockRecorder) DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts any) *MockQUICDialerDialQuicCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialQuic", reflect.TypeOf((*MockQUICDialer)(nil).DialQuic), ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
|
||||||
|
return &MockQUICDialerDialQuicCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockQUICDialerDialQuicCall wrap *gomock.Call
|
||||||
|
type MockQUICDialerDialQuicCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall {
|
||||||
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockManagementDialer is a mock of ManagementDialer interface.
|
||||||
|
type MockManagementDialer struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockManagementDialerMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockManagementDialerMockRecorder is the mock recorder for MockManagementDialer.
|
||||||
|
type MockManagementDialerMockRecorder struct {
|
||||||
|
mock *MockManagementDialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockManagementDialer creates a new mock instance.
|
||||||
|
func NewMockManagementDialer(ctrl *gomock.Controller) *MockManagementDialer {
|
||||||
|
mock := &MockManagementDialer{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockManagementDialerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockManagementDialer) EXPECT() *MockManagementDialerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext mocks base method.
|
||||||
|
func (m *MockManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DialContext", ctx, network, addr)
|
||||||
|
ret0, _ := ret[0].(net.Conn)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext indicates an expected call of DialContext.
|
||||||
|
func (mr *MockManagementDialerMockRecorder) DialContext(ctx, network, addr any) *MockManagementDialerDialContextCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockManagementDialer)(nil).DialContext), ctx, network, addr)
|
||||||
|
return &MockManagementDialerDialContextCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockManagementDialerDialContextCall wrap *gomock.Call
|
||||||
|
type MockManagementDialerDialContextCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockManagementDialerDialContextCall) Return(arg0 net.Conn, arg1 error) *MockManagementDialerDialContextCall {
|
||||||
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockManagementDialerDialContextCall) Do(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockManagementDialerDialContextCall) DoAndReturn(f func(context.Context, string, string) (net.Conn, error)) *MockManagementDialerDialContextCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
@@ -3,3 +3,5 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
|
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
|
||||||
|
|
||||||
|
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go"
|
||||||
|
|||||||
@@ -0,0 +1,338 @@
|
|||||||
|
package prechecks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection/dialopts"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
perProbeDialTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// Action messages for each probe outcome.
|
||||||
|
actionDNSFail = "Ensure your DNS resolver can resolve '%s'. Run: dig A %s @1.1.1.1. If that fails, contact your network administrator."
|
||||||
|
actionQUICBlocked = "QUIC traffic failed to connect to port 7844."
|
||||||
|
actionHTTP2Blocked = "Allow outbound TCP on port 7844."
|
||||||
|
actionAPIUnreachable = "cloudflared will still run, but automatic software updates are unavailable. " +
|
||||||
|
"Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates."
|
||||||
|
|
||||||
|
// Component names for CheckResult.
|
||||||
|
componentDNSResolution = "DNS Resolution"
|
||||||
|
componentUDPConnectivity = "UDP Connectivity"
|
||||||
|
componentTCPConnectivity = "TCP Connectivity"
|
||||||
|
componentCloudflareAPI = "Cloudflare API"
|
||||||
|
|
||||||
|
// Target identifiers for CheckResult.
|
||||||
|
targetPortQUIC = "Port 7844 (QUIC)"
|
||||||
|
targetPortHTTP2 = "Port 7844 (HTTP/2)"
|
||||||
|
targetAPI = "api.cloudflare.com:443"
|
||||||
|
|
||||||
|
// Details messages for CheckResult.
|
||||||
|
detailsNoAddressesReturned = "No addresses returned"
|
||||||
|
detailsResolvedSuccessfully = "Resolved successfully"
|
||||||
|
detailsHandshakeFailed = "Handshake failed"
|
||||||
|
detailsHandshakeSuccessful = "Handshake successful"
|
||||||
|
detailsBlockedOrUnreachable = "Blocked or unreachable"
|
||||||
|
detailsTLSHandshakeSuccessful = "TLS handshake successful"
|
||||||
|
detailsConnectionFailed = "Connection failed"
|
||||||
|
detailsTCPPortReachable = "TCP port reachable (TLS not validated)"
|
||||||
|
detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
|
||||||
|
|
||||||
|
// Region hostname templates.
|
||||||
|
region1Global = "region1.v2.argotunnel.com"
|
||||||
|
region2Global = "region2.v2.argotunnel.com"
|
||||||
|
region1US = "us-region1.v2.argotunnel.com"
|
||||||
|
region2US = "us-region2.v2.argotunnel.com"
|
||||||
|
region1Fed = "fed-region1.v2.argotunnel.com"
|
||||||
|
region2Fed = "fed-region2.v2.argotunnel.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EdgeDNSResolver struct {
|
||||||
|
Log *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EdgeDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, error) {
|
||||||
|
return allregions.EdgeDiscovery(r.Log, allregions.RegionalServiceName(region))
|
||||||
|
}
|
||||||
|
|
||||||
|
type EdgeTCPDialer struct{}
|
||||||
|
|
||||||
|
func (d *EdgeTCPDialer) DialEdge(
|
||||||
|
ctx context.Context,
|
||||||
|
timeout time.Duration,
|
||||||
|
tlsConfig *tls.Config,
|
||||||
|
addr *net.TCPAddr,
|
||||||
|
localIP net.IP,
|
||||||
|
) (net.Conn, error) {
|
||||||
|
return edgedial.DialEdge(ctx, timeout, tlsConfig, addr, localIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EdgeQUICDialer struct{}
|
||||||
|
|
||||||
|
func (d *EdgeQUICDialer) DialQuic(
|
||||||
|
ctx context.Context,
|
||||||
|
quicConfig *quic.Config,
|
||||||
|
tlsConfig *tls.Config,
|
||||||
|
addr netip.AddrPort,
|
||||||
|
localAddr net.IP,
|
||||||
|
connIndex uint8,
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
opts dialopts.DialOpts,
|
||||||
|
) (quic.Connection, error) {
|
||||||
|
return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetManagementDialer struct {
|
||||||
|
Dialer net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *NetManagementDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return d.Dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeDNS resolves edge addresses for the given region via the supplied
|
||||||
|
// DNSResolver and returns a CheckResult for each region discovered. If
|
||||||
|
// resolution fails for all regions, every result will carry StatusFail.
|
||||||
|
func probeDNS(
|
||||||
|
resolver DNSResolver,
|
||||||
|
region string,
|
||||||
|
) ([][]*allregions.EdgeAddr, []CheckResult) {
|
||||||
|
addrGroups, err := resolver.Resolve(region)
|
||||||
|
if err != nil || len(addrGroups) == 0 {
|
||||||
|
detail := detailsNoAddressesReturned
|
||||||
|
if err != nil {
|
||||||
|
detail = err.Error()
|
||||||
|
}
|
||||||
|
region1Target, region2Target := regionTargets(region)
|
||||||
|
return nil, []CheckResult{
|
||||||
|
{
|
||||||
|
Type: ProbeTypeDNS,
|
||||||
|
Component: componentDNSResolution,
|
||||||
|
Target: region1Target,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detail,
|
||||||
|
Action: fmt.Sprintf(actionDNSFail, region1Target, region1Target),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: ProbeTypeDNS,
|
||||||
|
Component: componentDNSResolution,
|
||||||
|
Target: region2Target,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detail,
|
||||||
|
Action: fmt.Sprintf(actionDNSFail, region2Target, region2Target),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
region1Target, region2Target := regionTargets(region)
|
||||||
|
targets := []string{region1Target, region2Target}
|
||||||
|
|
||||||
|
results := make([]CheckResult, 0, len(addrGroups))
|
||||||
|
for i, group := range addrGroups {
|
||||||
|
target := fmt.Sprintf("region%d.v2.argotunnel.com", i+1)
|
||||||
|
if i < len(targets) {
|
||||||
|
target = targets[i]
|
||||||
|
}
|
||||||
|
if len(group) == 0 {
|
||||||
|
results = append(results, CheckResult{
|
||||||
|
Type: ProbeTypeDNS,
|
||||||
|
Component: componentDNSResolution,
|
||||||
|
Target: target,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detailsNoAddressesReturned,
|
||||||
|
Action: fmt.Sprintf(actionDNSFail, target, target),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
results = append(results, CheckResult{
|
||||||
|
Type: ProbeTypeDNS,
|
||||||
|
Component: componentDNSResolution,
|
||||||
|
Target: target,
|
||||||
|
ProbeStatus: Pass,
|
||||||
|
Details: detailsResolvedSuccessfully,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrGroups, results
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeQUIC performs a QUIC handshake to a single edge address and returns a
|
||||||
|
// CheckResult. The connection is closed immediately after the handshake – no
|
||||||
|
// streams are opened and no RPC frames are sent – to avoid triggering the OTD
|
||||||
|
// registration timeout (TUN-6732). The probe SNI (probe.cftunnel.com) is used
|
||||||
|
// instead of the production quic.cftunnel.com to prevent OTD log noise.
|
||||||
|
//
|
||||||
|
// A per-probe deadline (perProbeDialTimeout) is applied on top of the parent
|
||||||
|
// context so that a single blocked handshake cannot consume the entire suite
|
||||||
|
// budget.
|
||||||
|
func probeQUIC(
|
||||||
|
ctx context.Context,
|
||||||
|
dialer QUICDialer,
|
||||||
|
addr *allregions.EdgeAddr,
|
||||||
|
logger *zerolog.Logger,
|
||||||
|
) CheckResult {
|
||||||
|
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tlsSettings := connection.QUIC.ProbeTLSSettings()
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: tlsSettings.ServerName,
|
||||||
|
NextProtos: tlsSettings.NextProtos,
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
CurvePreferences: []tls.CurveID{tls.CurveP256},
|
||||||
|
}
|
||||||
|
|
||||||
|
// We call dialer.DialQuic with isProbe = true, which bypasses connIndex check.
|
||||||
|
// Therefore, whatever we add to connIndex will not be relevant.
|
||||||
|
edgeAddrPort := addr.UDP.AddrPort()
|
||||||
|
conn, err := dialer.DialQuic(
|
||||||
|
dialCtx,
|
||||||
|
&quic.Config{},
|
||||||
|
tlsConfig,
|
||||||
|
edgeAddrPort,
|
||||||
|
nil,
|
||||||
|
math.MaxUint8,
|
||||||
|
logger,
|
||||||
|
dialopts.DialOpts{SkipPortReuse: true},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeQUIC,
|
||||||
|
Component: componentUDPConnectivity,
|
||||||
|
Target: targetPortQUIC,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detailsHandshakeFailed,
|
||||||
|
Action: actionQUICBlocked,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.CloseWithError(0, "precheck complete"); err != nil {
|
||||||
|
logger.Debug().Err(err).Msg("Failed to close QUIC connection after successful handshake")
|
||||||
|
}
|
||||||
|
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeQUIC,
|
||||||
|
Component: componentUDPConnectivity,
|
||||||
|
Target: targetPortQUIC,
|
||||||
|
ProbeStatus: Pass,
|
||||||
|
Details: detailsHandshakeSuccessful,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeHTTP2 performs a TCP + TLS handshake to a single edge address and
|
||||||
|
// returns a CheckResult. The connection is closed immediately after the
|
||||||
|
// handshake – no HTTP/2 frames are sent – to keep the probe minimal. The probe
|
||||||
|
// SNI (probe.cftunnel.com) is used instead of the production h2.cftunnel.com
|
||||||
|
// to prevent OTD log noise.
|
||||||
|
//
|
||||||
|
// The dial timeout is capped at perProbeDialTimeout so that a single blocked
|
||||||
|
// dial cannot exhaust the entire suite budget.
|
||||||
|
func probeHTTP2(ctx context.Context, dialer TCPDialer, addr *allregions.EdgeAddr) CheckResult {
|
||||||
|
tlsSettings := connection.HTTP2.ProbeTLSSettings()
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: tlsSettings.ServerName,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
CurvePreferences: []tls.CurveID{tls.CurveP256},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.DialEdge(ctx, perProbeDialTimeout, tlsConfig, addr.TCP, nil)
|
||||||
|
if err != nil {
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeHTTP2,
|
||||||
|
Component: componentTCPConnectivity,
|
||||||
|
Target: targetPortHTTP2,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detailsBlockedOrUnreachable,
|
||||||
|
Action: actionHTTP2Blocked,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeHTTP2,
|
||||||
|
Component: componentTCPConnectivity,
|
||||||
|
Target: targetPortHTTP2,
|
||||||
|
ProbeStatus: Pass,
|
||||||
|
Details: detailsTLSHandshakeSuccessful,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeManagementAPI tests TCP connectivity to api.cloudflare.com:443. A
|
||||||
|
// successful TCP connection (no TLS handshake required) confirms the port is
|
||||||
|
// reachable. This probe is always a soft failure: the tunnel can run without
|
||||||
|
// it, but automatic software updates will be unavailable.
|
||||||
|
func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResult {
|
||||||
|
dialCtx, cancel := context.WithTimeout(ctx, perProbeDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(dialCtx, "tcp", targetAPI)
|
||||||
|
if err != nil {
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeManagementAPI,
|
||||||
|
Component: componentCloudflareAPI,
|
||||||
|
Target: targetAPI,
|
||||||
|
ProbeStatus: Fail,
|
||||||
|
Details: detailsConnectionFailed,
|
||||||
|
Action: actionAPIUnreachable,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
return CheckResult{
|
||||||
|
Type: ProbeTypeManagementAPI,
|
||||||
|
Component: componentCloudflareAPI,
|
||||||
|
Target: targetAPI,
|
||||||
|
ProbeStatus: Pass,
|
||||||
|
Details: detailsTCPPortReachable,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipResult(probeType ProbeType, component, target string) CheckResult {
|
||||||
|
return CheckResult{
|
||||||
|
Type: probeType,
|
||||||
|
Component: component,
|
||||||
|
Target: target,
|
||||||
|
ProbeStatus: Skip,
|
||||||
|
Details: detailsDNSPrerequisiteFailed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regionTargets returns the human-readable hostnames for region1 and region2
|
||||||
|
// based on the optional region flag value.
|
||||||
|
func regionTargets(region string) (string, string) {
|
||||||
|
switch region {
|
||||||
|
case "us":
|
||||||
|
return region1US, region2US
|
||||||
|
case "fed":
|
||||||
|
return region1Fed, region2Fed
|
||||||
|
default:
|
||||||
|
return region1Global, region2Global
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrsByFamily extracts one V4 and one V6 address from a resolved CNAME group
|
||||||
|
// using allregions.NewRegion so that the IP-version preference logic matches
|
||||||
|
// production exactly. When cfg.IPVersion restricts to a single family the
|
||||||
|
// excluded family's pointer is nil.
|
||||||
|
func addrsByFamily(group []*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) (v4, v6 *allregions.EdgeAddr) {
|
||||||
|
if ipVersion != allregions.IPv6Only {
|
||||||
|
v4 = allregions.NewRegion(group, allregions.IPv4Only).GetAnyAddress()
|
||||||
|
}
|
||||||
|
if ipVersion != allregions.IPv4Only {
|
||||||
|
v6 = allregions.NewRegion(group, allregions.IPv6Only).GetAnyAddress()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -0,0 +1,536 @@
|
|||||||
|
package prechecks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/quic-go/quic-go"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||||
|
"github.com/cloudflare/cloudflared/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test constants for repeated string values.
|
||||||
|
const (
|
||||||
|
testRegion1Global = region1Global
|
||||||
|
testRegion2Global = region2Global
|
||||||
|
testRegion1US = region1US
|
||||||
|
testRegion2US = region2US
|
||||||
|
testRegion1Fed = region1Fed
|
||||||
|
testRegion2Fed = region2Fed
|
||||||
|
testRegion1EU = "eu-region1.v2.argotunnel.com"
|
||||||
|
testRegion2EU = "eu-region2.v2.argotunnel.com"
|
||||||
|
|
||||||
|
testEdgePort = 7844
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockQuicConnection is a minimal test double for quic.Connection.
|
||||||
|
type mockQuicConnection struct {
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) RemoteAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error {
|
||||||
|
return m.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) Context() context.Context {
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
|
||||||
|
return quic.ConnectionState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockQuicConnection) AddPath(*quic.Transport) (*quic.Path, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to create test edge addresses.
|
||||||
|
func createTestEdgeAddr(ip string, port int, version allregions.EdgeIPVersion) *allregions.EdgeAddr {
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
return &allregions.EdgeAddr{
|
||||||
|
TCP: &net.TCPAddr{IP: parsedIP, Port: port},
|
||||||
|
UDP: &net.UDPAddr{IP: parsedIP, Port: port},
|
||||||
|
IPVersion: version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeDNS tests.
|
||||||
|
|
||||||
|
func TestProbeDNS_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||||
|
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr, v6Addr}}, nil)
|
||||||
|
|
||||||
|
addrs, results := probeDNS(resolver, "")
|
||||||
|
|
||||||
|
require.NotNil(t, addrs)
|
||||||
|
require.Len(t, results, 1)
|
||||||
|
assert.Len(t, addrs, 1)
|
||||||
|
assert.Equal(t, ProbeTypeDNS, results[0].Type)
|
||||||
|
assert.Equal(t, testRegion1Global, results[0].Target)
|
||||||
|
assert.Equal(t, Pass, results[0].ProbeStatus)
|
||||||
|
assert.Equal(t, detailsResolvedSuccessfully, results[0].Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeDNS_MultipleRegions(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
v4Addr1 := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
v4Addr2 := createTestEdgeAddr("192.0.2.2", testEdgePort, allregions.V4)
|
||||||
|
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr1}, {v4Addr2}}, nil)
|
||||||
|
|
||||||
|
addrs, results := probeDNS(resolver, "")
|
||||||
|
|
||||||
|
require.NotNil(t, addrs)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
assert.Len(t, addrs, 2)
|
||||||
|
|
||||||
|
assert.Equal(t, testRegion1Global, results[0].Target)
|
||||||
|
assert.Equal(t, Pass, results[0].ProbeStatus)
|
||||||
|
|
||||||
|
assert.Equal(t, testRegion2Global, results[1].Target)
|
||||||
|
assert.Equal(t, Pass, results[1].ProbeStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeDNS_ResolverError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("").Return(nil, errors.New("DNS lookup failed"))
|
||||||
|
|
||||||
|
addrs, results := probeDNS(resolver, "")
|
||||||
|
|
||||||
|
assert.Nil(t, addrs)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
|
||||||
|
assert.Equal(t, Fail, results[0].ProbeStatus)
|
||||||
|
assert.Equal(t, "DNS lookup failed", results[0].Details)
|
||||||
|
assert.Contains(t, results[0].Action, testRegion1Global)
|
||||||
|
assert.Contains(t, results[1].Action, testRegion2Global)
|
||||||
|
|
||||||
|
assert.Equal(t, Fail, results[1].ProbeStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeDNS_EmptyResults(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{}, nil)
|
||||||
|
|
||||||
|
addrs, results := probeDNS(resolver, "")
|
||||||
|
|
||||||
|
assert.Nil(t, addrs)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
assert.Equal(t, Fail, results[0].ProbeStatus)
|
||||||
|
assert.Equal(t, "No addresses returned", results[0].Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeDNS_EmptyGroup(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{}}, nil)
|
||||||
|
|
||||||
|
addrs, results := probeDNS(resolver, "")
|
||||||
|
|
||||||
|
require.NotNil(t, addrs)
|
||||||
|
require.Len(t, results, 1)
|
||||||
|
assert.Equal(t, Fail, results[0].ProbeStatus)
|
||||||
|
assert.Equal(t, "No addresses returned", results[0].Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeDNS_RegionFlag(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
resolver := mocks.NewMockDNSResolver(ctrl)
|
||||||
|
resolver.EXPECT().Resolve("us").Return([][]*allregions.EdgeAddr{{v4Addr}}, nil)
|
||||||
|
|
||||||
|
_, results := probeDNS(resolver, "us")
|
||||||
|
|
||||||
|
require.Len(t, results, 1)
|
||||||
|
assert.Equal(t, testRegion1US, results[0].Target)
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeQUIC tests.
|
||||||
|
|
||||||
|
func TestProbeQUIC_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockConn := &mockQuicConnection{}
|
||||||
|
dialer := mocks.NewMockQUICDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
logger := zerolog.New(nil)
|
||||||
|
|
||||||
|
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeQUIC_DialError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockQUICDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused"))
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
logger := zerolog.New(nil)
|
||||||
|
|
||||||
|
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||||
|
assert.Equal(t, Fail, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsHandshakeFailed, result.Details)
|
||||||
|
assert.Equal(t, actionQUICBlocked, result.Action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockConn := &mockQuicConnection{closeErr: errors.New("close failed")}
|
||||||
|
dialer := mocks.NewMockQUICDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
logger := zerolog.New(nil)
|
||||||
|
|
||||||
|
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeQUIC_ContextTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockQUICDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, context.DeadlineExceeded)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
logger := zerolog.New(nil)
|
||||||
|
|
||||||
|
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||||
|
|
||||||
|
assert.Equal(t, Fail, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsHandshakeFailed, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeHTTP2 tests.
|
||||||
|
|
||||||
|
func TestProbeHTTP2_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockTCPDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
|
||||||
|
result := probeHTTP2(context.Background(), dialer, addr)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeHTTP2, result.Type)
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsTLSHandshakeSuccessful, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeHTTP2_DialError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockTCPDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("connection refused"))
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
|
||||||
|
result := probeHTTP2(context.Background(), dialer, addr)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeHTTP2, result.Type)
|
||||||
|
assert.Equal(t, Fail, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsBlockedOrUnreachable, result.Details)
|
||||||
|
assert.Equal(t, actionHTTP2Blocked, result.Action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// probeManagementAPI tests.
|
||||||
|
|
||||||
|
func TestProbeManagementAPI_Success(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockManagementDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(&net.TCPConn{}, nil)
|
||||||
|
|
||||||
|
result := probeManagementAPI(context.Background(), dialer)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeManagementAPI, result.Type)
|
||||||
|
assert.Equal(t, "Cloudflare API", result.Component)
|
||||||
|
assert.Equal(t, "api.cloudflare.com:443", result.Target)
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsTCPPortReachable, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeManagementAPI_DialError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockManagementDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialContext(gomock.Any(), "tcp", "api.cloudflare.com:443").Return(nil, errors.New("connection refused"))
|
||||||
|
|
||||||
|
result := probeManagementAPI(context.Background(), dialer)
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeManagementAPI, result.Type)
|
||||||
|
assert.Equal(t, Fail, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsConnectionFailed, result.Details)
|
||||||
|
assert.Equal(t, actionAPIUnreachable, result.Action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// skipResult tests.
|
||||||
|
|
||||||
|
func TestSkipResult(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)")
|
||||||
|
|
||||||
|
assert.Equal(t, ProbeTypeQUIC, result.Type)
|
||||||
|
assert.Equal(t, "UDP Connectivity", result.Component)
|
||||||
|
assert.Equal(t, "Port 7844 (QUIC)", result.Target)
|
||||||
|
assert.Equal(t, Skip, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsDNSPrerequisiteFailed, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
// regionTargets tests.
|
||||||
|
|
||||||
|
func TestRegionTargets(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
region string
|
||||||
|
wantRegion1 string
|
||||||
|
wantRegion2 string
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty region returns global hostnames",
|
||||||
|
region: "",
|
||||||
|
wantRegion1: testRegion1Global,
|
||||||
|
wantRegion2: testRegion2Global,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "us region returns US hostnames",
|
||||||
|
region: "us",
|
||||||
|
wantRegion1: testRegion1US,
|
||||||
|
wantRegion2: testRegion2US,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fed region returns fed hostnames",
|
||||||
|
region: "fed",
|
||||||
|
wantRegion1: testRegion1Fed,
|
||||||
|
wantRegion2: testRegion2Fed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown region defaults to global hostnames",
|
||||||
|
region: "eu",
|
||||||
|
wantRegion1: testRegion1Global,
|
||||||
|
wantRegion2: testRegion2Global,
|
||||||
|
description: "Unknown regions should default to global hostnames",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gotR1, gotR2 := regionTargets(tt.region)
|
||||||
|
assert.Equal(t, tt.wantRegion1, gotR1)
|
||||||
|
assert.Equal(t, tt.wantRegion2, gotR2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrsByFamily tests.
|
||||||
|
|
||||||
|
func TestAddrsByFamily(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
v4Addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
|
||||||
|
v6Addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
group []*allregions.EdgeAddr
|
||||||
|
ipVersion allregions.ConfigIPVersion
|
||||||
|
wantV4 bool
|
||||||
|
wantV6 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "auto returns both v4 and v6",
|
||||||
|
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
|
||||||
|
ipVersion: allregions.Auto,
|
||||||
|
wantV4: true,
|
||||||
|
wantV6: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 only returns v4 and nil v6",
|
||||||
|
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
|
||||||
|
ipVersion: allregions.IPv4Only,
|
||||||
|
wantV4: true,
|
||||||
|
wantV6: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 only returns nil v4 and v6",
|
||||||
|
group: []*allregions.EdgeAddr{v4Addr, v6Addr},
|
||||||
|
ipVersion: allregions.IPv6Only,
|
||||||
|
wantV4: false,
|
||||||
|
wantV6: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty group returns nil for both",
|
||||||
|
group: []*allregions.EdgeAddr{},
|
||||||
|
ipVersion: allregions.Auto,
|
||||||
|
wantV4: false,
|
||||||
|
wantV6: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only v4 available returns v4 and nil v6",
|
||||||
|
group: []*allregions.EdgeAddr{v4Addr},
|
||||||
|
ipVersion: allregions.Auto,
|
||||||
|
wantV4: true,
|
||||||
|
wantV6: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
gotV4, gotV6 := addrsByFamily(tt.group, tt.ipVersion)
|
||||||
|
if tt.wantV4 {
|
||||||
|
assert.NotNil(t, gotV4)
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, gotV4)
|
||||||
|
}
|
||||||
|
if tt.wantV6 {
|
||||||
|
assert.NotNil(t, gotV6)
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, gotV6)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6 address tests for probeQUIC.
|
||||||
|
|
||||||
|
func TestProbeQUIC_IPv6Address(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockConn := &mockQuicConnection{}
|
||||||
|
dialer := mocks.NewMockQUICDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||||
|
logger := zerolog.New(nil)
|
||||||
|
|
||||||
|
result := probeQUIC(context.Background(), dialer, addr, &logger)
|
||||||
|
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6 address tests for probeHTTP2.
|
||||||
|
|
||||||
|
func TestProbeHTTP2_IPv6Address(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
dialer := mocks.NewMockTCPDialer(ctrl)
|
||||||
|
dialer.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&net.TCPConn{}, nil)
|
||||||
|
|
||||||
|
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
|
||||||
|
|
||||||
|
result := probeHTTP2(context.Background(), dialer, addr)
|
||||||
|
|
||||||
|
assert.Equal(t, Pass, result.ProbeStatus)
|
||||||
|
}
|
||||||
@@ -23,9 +23,6 @@ import (
|
|||||||
// system resolver fails, and resolves each discovered hostname via
|
// system resolver fails, and resolves each discovered hostname via
|
||||||
// net.LookupIP. The returned slice already has each address tagged with
|
// net.LookupIP. The returned slice already has each address tagged with
|
||||||
// .IPVersion = V4 or V6.
|
// .IPVersion = V4 or V6.
|
||||||
//
|
|
||||||
// Note: allregions.EdgeDiscovery must be exported (currently unexported as
|
|
||||||
// edgeDiscovery) before a production adapter can be wired up.
|
|
||||||
type DNSResolver interface {
|
type DNSResolver interface {
|
||||||
// Resolve performs edge discovery for the given region string (empty for
|
// Resolve performs edge discovery for the given region string (empty for
|
||||||
// global, "us" / "fed" for regional endpoints) and returns the resolved
|
// global, "us" / "fed" for regional endpoints) and returns the resolved
|
||||||
Reference in New Issue
Block a user