diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 9a33fb49..2ed385e2 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -539,19 +539,11 @@ func runPrechecks(c *cli.Context, log *zerolog.Logger, region string) { cfg := prechecks.Config{ Region: region, IPVersion: ipVersion, - } - - // Mirror the static/dynamic edge selection from supervisor/supervisor.go: - // when --edge addresses are provided, bypass DNS discovery entirely. - var dnsResolver prechecks.DNSResolver - if edgeAddrs := c.StringSlice(cfdflags.Edge); len(edgeAddrs) > 0 { - dnsResolver = &prechecks.StaticEdgeDNSResolver{Addrs: edgeAddrs, Log: log} - } else { - dnsResolver = &prechecks.EdgeDNSResolver{Log: log} + EdgeAddrs: c.StringSlice(cfdflags.Edge), } dialers := prechecks.RunDialers{ - DNSResolver: dnsResolver, + DNSResolver: &prechecks.EdgeDNSResolver{Log: log}, TCPDialer: &prechecks.EdgeTCPDialer{}, QUICDialer: &prechecks.EdgeQUICDialer{}, ManagementDialer: &prechecks.NetManagementDialer{Dialer: net.Dialer{}}, diff --git a/prechecks/checker.go b/prechecks/checker.go index 4426dd30..8d98f61f 100644 --- a/prechecks/checker.go +++ b/prechecks/checker.go @@ -30,16 +30,17 @@ type RunDialers struct { ManagementDialer ManagementDialer } -// TransportResults holds the per-region results for each transport probe type. -// Each slice has one entry per DNS-resolved region, in the same order as dnsResults. +// TransportResults holds the per-target results for each transport probe type. +// Each slice has one entry per resolved target group, in the same order as the +// target labels slice. type TransportResults struct { - QUIC []CheckResult // one per region - HTTP2 []CheckResult // one per region - ManagementAPI CheckResult // single target, no regions + QUIC []CheckResult // one per target group + HTTP2 []CheckResult // one per target group + ManagementAPI CheckResult // single target, no groups } // Collect returns all results as a slice in a consistent order for reporting: -// all QUIC rows first (one per region), then all HTTP2 rows, then Management API. +// all QUIC rows first (one per target), then all HTTP2 rows, then Management API. func (tr TransportResults) Collect() []CheckResult { results := make([]CheckResult, 0, len(tr.QUIC)+len(tr.HTTP2)+1) results = append(results, tr.QUIC...) @@ -50,8 +51,11 @@ func (tr TransportResults) Collect() []CheckResult { // Run executes the following connectivity pre-checks: // -// 1. DNS resolution (sequential – transport probes depend on its output). -// 2. QUIC, HTTP/2, and Management API probes run concurrently. +// 1. Edge address resolution — either DNS-based SRV discovery (normal path) +// or direct resolution of --edge addresses (static path). The static path +// skips DNS probe rows entirely since there are no SRV records to validate. +// 2. QUIC, HTTP/2, and Management API probes run concurrently against the +// resolved addresses. // // Each failed probe is retried up to maxRetries times with exponential backoff. // The suite is bounded by cfg.Timeout (defaultTimeout if zero). @@ -64,19 +68,39 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru ctx, cancel := context.WithTimeout(ctx, cfg.Timeout) defer cancel() - // Build TLS configs once per protocol + // Build TLS configs once per protocol. quicTLSConfig, quicTLSErr := probeTLSConfig(caCert, connection.QUIC) http2TLSConfig, http2TLSErr := probeTLSConfig(caCert, connection.HTTP2) - // 1) DNS – must complete before transport probes know which addresses to dial. - addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region) + // 1) Resolve edge addresses. Each ResolvedTarget bundles its addr group + // with the DNS CheckResult that labels it, keeping the two in sync. + var resolvedTargets []ResolvedTarget + if len(cfg.EdgeAddrs) > 0 { + // Static path: explicit --edge addresses, one ResolvedTarget per addr. + resolvedTargets = resolveStaticEdge(cfg.EdgeAddrs, log) + } else { + // Normal path: SRV-based discovery; DNS rows carry Pass or Fail status. + resolvedTargets = runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region) + } - dnsOK := !slices.ContainsFunc(dnsResults, func(r CheckResult) bool { - return r.ProbeStatus != Pass + // Extract parallel slices for the transport probe layer. + // nolint:prealloc // False positive. The linter is confused by the append used when producing Report.Results + dnsResults := make([]CheckResult, len(resolvedTargets)) + perGroupAddrs := make([][]*allregions.EdgeAddr, len(resolvedTargets)) + targetLabels := make([]string, len(resolvedTargets)) + for i, rt := range resolvedTargets { + dnsResults[i] = rt.DNSResult + perGroupAddrs[i] = rt.Addrs + targetLabels[i] = rt.DNSResult.Target + } + + // dnsOK is true when at least one target has addresses to probe. + dnsOK := slices.ContainsFunc(resolvedTargets, func(r ResolvedTarget) bool { + return len(r.Addrs) > 0 }) - // 2) Run probes concurrently. Each probe type gets its own buffered channel — - // one send, one receive, no routing or name-parsing required. + // 2) Run transport probes concurrently. Each probe type gets its own + // buffered channel — one send, one receive, no routing required. var results TransportResults mgmtCh := make(chan CheckResult) @@ -85,12 +109,12 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru }() if !dnsOK { - // DNS failed: emit one skip row per region so the table stays consistent. - results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, componentUDPConnectivity) - results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity) + // No addresses available: emit one skip row per target so the table + // stays consistent with the DNS rows above. + results.QUIC = skipResultsForTargets(dnsResults, ProbeTypeQUIC, componentUDPConnectivity) + results.HTTP2 = skipResultsForTargets(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity) } else { - perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion) - regionTargets := dnsTargets(dnsResults) + filteredAddrs := addrsByGroup(perGroupAddrs, cfg.IPVersion) quicCh := make(chan []CheckResult, 1) http2Ch := make(chan []CheckResult, 1) @@ -99,11 +123,11 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru if quicTLSErr != nil { log.Warn().Err(quicTLSErr).Msg("Failed to build QUIC probe TLS config") quicCh <- tlsConfigErrResults(ProbeTypeQUIC, componentUDPConnectivity, - regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked) + targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked) return } - quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, componentUDPConnectivity, - perRegionAddrs, regionTargets, + quicCh <- probeAllTargets(ctx, ProbeTypeQUIC, componentUDPConnectivity, + filteredAddrs, targetLabels, func(addr *allregions.EdgeAddr) CheckResult { return probeQUIC(ctx, quicTLSConfig, runDialers.QUICDialer, addr, log) }) @@ -113,11 +137,11 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru if http2TLSErr != nil { log.Warn().Err(http2TLSErr).Msg("Failed to build HTTP/2 probe TLS config") http2Ch <- tlsConfigErrResults(ProbeTypeHTTP2, componentTCPConnectivity, - regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked) + targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked) return } - http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, componentTCPConnectivity, - perRegionAddrs, regionTargets, + http2Ch <- probeAllTargets(ctx, ProbeTypeHTTP2, componentTCPConnectivity, + filteredAddrs, targetLabels, func(addr *allregions.EdgeAddr) CheckResult { return probeHTTP2(ctx, http2TLSConfig, runDialers.TCPDialer, addr) }) @@ -136,11 +160,11 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru } } -// tlsConfigErrResults returns one Fail CheckResult per region target, used when +// tlsConfigErrResults returns one Fail CheckResult per target, used when // TLS config construction fails before any dial is attempted. -func tlsConfigErrResults(probeType ProbeType, component string, regionTargets []string, details, action string) []CheckResult { - results := make([]CheckResult, len(regionTargets)) - for i, target := range regionTargets { +func tlsConfigErrResults(probeType ProbeType, component string, targets []string, details, action string) []CheckResult { + results := make([]CheckResult, len(targets)) + for i, target := range targets { results[i] = CheckResult{ Type: probeType, Component: component, @@ -153,47 +177,32 @@ func tlsConfigErrResults(probeType ProbeType, component string, regionTargets [] return results } -func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) { - var addrGroups [][]*allregions.EdgeAddr - var dnsResults []CheckResult - withRetry(ctx, maxRetries, func() bool { - addrGroups, dnsResults = probeDNS(resolver, region) - for _, r := range dnsResults { - if r.ProbeStatus == Fail { - return false - } - } - return len(dnsResults) > 0 - }) - return addrGroups, dnsResults -} - -// probeAllRegions probes each region sequentially and returns one CheckResult -// per region. Within each region, all available addresses (V4 and/or V6) are -// tried and the best result is kept. -func probeAllRegions( +// probeAllTargets probes each target group sequentially and returns one +// CheckResult per group. Within each group, all available addresses (V4 and/or +// V6) are tried and the best result is kept. +func probeAllTargets( ctx context.Context, probeType ProbeType, component string, - perRegionAddrs [][]*allregions.EdgeAddr, - regionTargets []string, + perGroupAddrs [][]*allregions.EdgeAddr, + targets []string, probeFn func(*allregions.EdgeAddr) CheckResult, ) []CheckResult { - results := make([]CheckResult, len(perRegionAddrs)) - for i, addrs := range perRegionAddrs { - results[i] = probeRegion(ctx, probeType, component, regionTargets[i], addrs, probeFn) + results := make([]CheckResult, len(perGroupAddrs)) + for i, addrs := range perGroupAddrs { + results[i] = probeTarget(ctx, probeType, component, targets[i], addrs, probeFn) } return results } -// probeRegion probes all addresses for a single region (typically one V4 and/or -// one V6) and returns the best result. Any address passing means the region is -// reachable, so Pass beats Fail within a region. -func probeRegion( +// probeTarget probes all addresses for a single target group (typically one V4 +// and/or one V6) and returns the best result. Any address passing means the +// target is reachable, so Pass beats Fail within a group. +func probeTarget( ctx context.Context, probeType ProbeType, component string, - regionTarget string, + target string, addrs []*allregions.EdgeAddr, probeFn func(*allregions.EdgeAddr) CheckResult, ) CheckResult { @@ -201,7 +210,7 @@ func probeRegion( return CheckResult{ Type: probeType, Component: component, - Target: regionTarget, + Target: target, ProbeStatus: Skip, Details: "No suitable address found for configured IP version", } @@ -213,7 +222,7 @@ func probeRegion( best = r } } - best.Target = regionTarget + best.Target = target return best } @@ -238,11 +247,11 @@ func probeWithRetry(ctx context.Context, addr *allregions.EdgeAddr, probeFn func return r } -// addrsByRegion returns the addresses to probe for each DNS-resolved region, -// preserving the per-region grouping. Each inner slice contains at most one V4 +// addrsByGroup returns the addresses to probe for each resolved target group, +// preserving the per-group structure. Each inner slice contains at most one V4 // and one V6 address (subject to ipVersion). -func addrsByRegion(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr { - perRegion := make([][]*allregions.EdgeAddr, 0, len(addrGroups)) +func addrsByGroup(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr { + perGroup := make([][]*allregions.EdgeAddr, 0, len(addrGroups)) for _, group := range addrGroups { v4, v6 := addrsByFamily(group, ipVersion) var addrs []*allregions.EdgeAddr @@ -252,27 +261,17 @@ func addrsByRegion(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.Con if v6 != nil { addrs = append(addrs, v6) } - perRegion = append(perRegion, addrs) + perGroup = append(perGroup, addrs) } - return perRegion + return perGroup } -// dnsTargets extracts the Target hostname from each DNS CheckResult so that -// transport probe rows reuse the same region hostnames. -func dnsTargets(dnsResults []CheckResult) []string { - targets := make([]string, len(dnsResults)) - for i, r := range dnsResults { - targets[i] = r.Target - } - return targets -} - -// skipResultsForRegions returns one skip CheckResult per DNS region, using each -// region's hostname as the Target so the output table row aligns with its DNS row. -func skipResultsForRegions(dnsResults []CheckResult, probeType ProbeType, component string) []CheckResult { - results := make([]CheckResult, len(dnsResults)) - for i, dns := range dnsResults { - results[i] = skipResult(probeType, component, dns.Target) +// skipResultsForTargets returns one skip CheckResult per entry in results, +// using each entry's Target label so the transport row aligns with its DNS row. +func skipResultsForTargets(targets []CheckResult, probeType ProbeType, component string) []CheckResult { + results := make([]CheckResult, len(targets)) + for i, t := range targets { + results[i] = skipResult(probeType, component, t.Target, detailsDNSPrerequisiteFailed) } return results } @@ -320,7 +319,7 @@ func suggestProtocol(quicResults, http2Results []CheckResult) *connection.Protoc } // withRetry calls fn up to 1+maxAttempts times, stopping as soon as fn returns -// true. Between attempts it sleeps with exponential backoff bounded by +// true. Between attempts, it sleeps with exponential backoff bounded by // maxRetryDelay, and stops early if ctx is done. func withRetry(ctx context.Context, maxAttempts int, fn func() bool) { b := backoff.NewWithoutJitter(maxRetryDelay, retryBaseDelay) diff --git a/prechecks/checker_test.go b/prechecks/checker_test.go index 7d19627e..36686e0b 100644 --- a/prechecks/checker_test.go +++ b/prechecks/checker_test.go @@ -420,20 +420,101 @@ func TestRun_BothFamiliesProbed(t *testing.T) { assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) } -// TestRun_IPv4OnlySkipsV6 verifies that when IPv4Only is configured only V4 -// addresses are probed (2 regions × 1 V4 = 2 dials per transport). -func TestRun_IPv4OnlySkipsV6(t *testing.T) { +// TestRun_IPVersionRestriction verifies that when a single IP family is +// configured, only that family is probed (2 regions × 1 addr = 2 dials per +// transport) and the excluded family is never dialled. +func TestRun_IPVersionRestriction(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ipVersion allregions.ConfigIPVersion + }{ + {"IPv4Only skips V6", allregions.IPv4Only}, + {"IPv6Only skips V4", allregions.IPv6Only}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) + // 2 regions × 1 addr per restricted family = 2 dials each. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: tt.ipVersion}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) + }) + } +} + +// TestRun_EdgeAddrs_SingleAddr verifies that a single --edge addr bypasses DNS +// probing. The report contains one DNS Skip row, transport rows labeled with +// the raw addr string, and the Management API row. +func TestRun_EdgeAddrs_SingleAddr(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - dns := mocks.NewMockDNSResolver(ctrl) tcp := mocks.NewMockTCPDialer(ctrl) quicD := mocks.NewMockQUICDialer(ctrl) mgmt := mocks.NewMockManagementDialer(ctrl) - dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) - // IPv4Only: only V4 addresses are probed → 2 regions × 1 V4 = 2 calls each. - // V6 addresses must never be dialed. + // DNS resolver must NOT be called when EdgeAddrs is set. + dns := mocks.NewMockDNSResolver(ctrl) + dns.EXPECT().Resolve(gomock.Any()).Times(0) + + // One addr resolves to one group → one dial per transport. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(1) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(1) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + cfg := Config{ + EdgeAddrs: []string{"127.0.0.1:7844"}, + Timeout: 2 * time.Second, + IPVersion: allregions.Auto, + } + report := Run(t.Context(), emptyCert, cfg, nopLogger(), + RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 1 DNS Skip + 1 QUIC + 1 HTTP2 + 1 API = 4 results. + requireStatuses(t, report, Pass, Pass, Pass, Pass) + assert.Equal(t, ProbeTypeDNS, report.Results[0].Type, "first row must be DNS skip") + assert.Equal(t, "127.0.0.1:7844", report.Results[1].Target, "QUIC target must be the raw --edge addr") + assert.Equal(t, "127.0.0.1:7844", report.Results[2].Target, "HTTP2 target must be the raw --edge addr") + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) +} + +// TestRun_EdgeAddrs_MultipleAddrs verifies that multiple --edge addrs produce +// one transport row per addr, each labeled with its original addr string. +func TestRun_EdgeAddrs_MultipleAddrs(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns := mocks.NewMockDNSResolver(ctrl) + dns.EXPECT().Resolve(gomock.Any()).Times(0) + + // Two addrs → two groups → two dials per transport. tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nopConn{}, nil).Times(2) quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -441,35 +522,60 @@ func TestRun_IPv4OnlySkipsV6(t *testing.T) { mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). Return(nopConn{}, nil) - report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only}, - nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + cfg := Config{ + EdgeAddrs: []string{"127.0.0.1:7844", "127.0.0.2:7844"}, + Timeout: 2 * time.Second, + IPVersion: allregions.Auto, + } + report := Run(t.Context(), emptyCert, cfg, nopLogger(), + RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + // 2 DNS Pass (one per addr) + 2 QUIC + 2 HTTP2 + 1 API = 7 results. requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) + assert.Equal(t, ProbeTypeDNS, report.Results[0].Type, "first row must be DNS skip addr1") + assert.Equal(t, "127.0.0.1:7844", report.Results[0].Target, "DNS skip addr1 label") + assert.Equal(t, ProbeTypeDNS, report.Results[1].Type, "second row must be DNS skip addr2") + assert.Equal(t, "127.0.0.2:7844", report.Results[1].Target, "DNS skip addr2 label") + assert.Equal(t, "127.0.0.1:7844", report.Results[2].Target, "QUIC addr1") + assert.Equal(t, "127.0.0.2:7844", report.Results[3].Target, "QUIC addr2") + assert.Equal(t, "127.0.0.1:7844", report.Results[4].Target, "HTTP2 addr1") + assert.Equal(t, "127.0.0.2:7844", report.Results[5].Target, "HTTP2 addr2") } -// TestRun_IPv6OnlySkipsV4 verifies that when IPv6Only is configured only V6 -// addresses are probed (2 regions × 1 V6 = 2 dials per transport). -func TestRun_IPv6OnlySkipsV4(t *testing.T) { +// TestRun_EdgeAddrs_UnresolvableAddr verifies that when all --edge addrs fail +// to resolve, the DNS resolver is not called and transport rows are skipped, +// mirroring the DNS skip row. +func TestRun_EdgeAddrs_UnresolvableAddr(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - dns := mocks.NewMockDNSResolver(ctrl) tcp := mocks.NewMockTCPDialer(ctrl) quicD := mocks.NewMockQUICDialer(ctrl) mgmt := mocks.NewMockManagementDialer(ctrl) - dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) - // IPv6Only: only V6 addresses are probed → 2 regions × 1 V6 = 2 calls each. - // V4 addresses must never be dialled. - tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(nopConn{}, nil).Times(2) - quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(&fakeQUICConn{}, nil).Times(2) + dns := mocks.NewMockDNSResolver(ctrl) + dns.EXPECT().Resolve(gomock.Any()).Times(0) + + // Unresolvable addr → no groups → no transport dials. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). Return(nopConn{}, nil) - report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only}, - nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + cfg := Config{ + EdgeAddrs: []string{"not-a-valid-addr"}, + Timeout: 2 * time.Second, + IPVersion: allregions.Auto, + } + report := Run(t.Context(), emptyCert, cfg, nopLogger(), + RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) - requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) + // 1 DNS Fail + 1 QUIC Skip + 1 HTTP2 Skip + 1 API = 4 results. + requireStatuses(t, report, Fail, Skip, Skip, Pass) + assert.Equal(t, ProbeTypeDNS, report.Results[0].Type) + assert.Equal(t, "not-a-valid-addr", report.Results[0].Target) + assert.Equal(t, ProbeTypeQUIC, report.Results[1].Type) + assert.Equal(t, ProbeTypeHTTP2, report.Results[2].Type) + assert.Nil(t, report.SuggestedProtocol) + assert.True(t, report.hasHardFail()) } diff --git a/prechecks/probes.go b/prechecks/probes.go index b99e8216..8142618f 100644 --- a/prechecks/probes.go +++ b/prechecks/probes.go @@ -40,6 +40,7 @@ const ( targetPortQUIC = "Port 7844 (QUIC)" targetPortHTTP2 = "Port 7844 (HTTP/2)" targetAPI = "api.cloudflare.com:443" + noDNSTarget = "No DNS target (Using edge flag)" // Details messages for CheckResult. detailsNoAddressesReturned = "No addresses returned" @@ -72,20 +73,6 @@ func (r *EdgeDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, erro return allregions.EdgeDiscovery(r.Log, allregions.RegionalServiceName(region)) } -// StaticEdgeDNSResolver implements DNSResolver for the --edge flag path. -type StaticEdgeDNSResolver struct { - Addrs []string - Log *zerolog.Logger -} - -func (r *StaticEdgeDNSResolver) Resolve(_ string) ([][]*allregions.EdgeAddr, error) { - resolved := allregions.ResolveAddrs(r.Addrs, r.Log) - if len(resolved) == 0 { - return nil, fmt.Errorf("failed to resolve any edge address") - } - return [][]*allregions.EdgeAddr{resolved}, nil -} - type EdgeTCPDialer struct{} func (d *EdgeTCPDialer) DialEdge( @@ -141,13 +128,15 @@ func probeTLSConfig(caCert string, p connection.Protocol) (*tls.Config, error) { } // probeDNS resolves edge addresses for the given region via the supplied -// DNSResolver and returns a CheckResult for each region discovered. If -// resolution fails for all regions, every result will carry StatusFail. +// DNSResolver and returns one ResolvedTarget per discovered region. If +// resolution fails entirely, every ResolvedTarget will carry a Fail DNSResult +// and nil Addrs. func probeDNS( resolver DNSResolver, region string, -) ([][]*allregions.EdgeAddr, []CheckResult) { +) []ResolvedTarget { region1Target, region2Target := regionTargets(region) + targets := []string{region1Target, region2Target} addrGroups, err := resolver.Resolve(region) if err != nil || len(addrGroups) == 0 { @@ -155,28 +144,31 @@ func probeDNS( if err != nil { detail = err.Error() } - return nil, []CheckResult{ - newDNSCheckResult(region1Target, Fail, detail, fmt.Sprintf(actionDNSFail, region1Target, region1Target)), - newDNSCheckResult(region2Target, Fail, detail, fmt.Sprintf(actionDNSFail, region2Target, region2Target)), + return []ResolvedTarget{ + {DNSResult: newDNSCheckResult(region1Target, Fail, detail, fmt.Sprintf(actionDNSFail, region1Target, region1Target))}, + {DNSResult: newDNSCheckResult(region2Target, Fail, detail, fmt.Sprintf(actionDNSFail, region2Target, region2Target))}, } } - targets := []string{region1Target, region2Target} - - results := make([]CheckResult, 0, len(addrGroups)) - for i, group := range addrGroups { - target := fmt.Sprintf("region%d.v2.argotunnel.com", i+1) - if i < len(targets) { - target = targets[i] + resolved := make([]ResolvedTarget, 0, len(addrGroups)) + for i, target := range targets { + if i >= len(addrGroups) { + break } + group := addrGroups[i] if len(group) == 0 { - results = append(results, newDNSCheckResult(target, Fail, detailsNoAddressesReturned, fmt.Sprintf(actionDNSFail, target, target))) + resolved = append(resolved, ResolvedTarget{ + DNSResult: newDNSCheckResult(target, Fail, detailsNoAddressesReturned, fmt.Sprintf(actionDNSFail, target, target)), + }) } else { - results = append(results, newDNSCheckResult(target, Pass, detailsResolvedSuccessfully, "")) + resolved = append(resolved, ResolvedTarget{ + Addrs: group, + DNSResult: newDNSCheckResult(target, Pass, detailsResolvedSuccessfully, ""), + }) } } - return addrGroups, results + return resolved } // probeQUIC performs a QUIC handshake to a single edge address and returns a @@ -296,13 +288,13 @@ func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResul } } -func skipResult(probeType ProbeType, component, target string) CheckResult { +func skipResult(probeType ProbeType, component, target string, details string) CheckResult { return CheckResult{ Type: probeType, Component: component, Target: target, ProbeStatus: Skip, - Details: detailsDNSPrerequisiteFailed, + Details: details, } } @@ -345,3 +337,39 @@ func addrsByFamily(group []*allregions.EdgeAddr, ipVersion allregions.ConfigIPVe } return } + +// runDNSProbe runs probeDNS with retry and returns []ResolvedTarget. +func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) []ResolvedTarget { + var targets []ResolvedTarget + withRetry(ctx, maxRetries, func() bool { + targets = probeDNS(resolver, region) + for _, t := range targets { + if t.DNSResult.ProbeStatus == Fail { + return false + } + } + return len(targets) > 0 + }) + return targets +} + +// resolveStaticEdge resolves each --edge addr individually, returning one +// ResolvedTarget per addr. Unresolvable addrs produce a Fail ResolvedTarget +// with nil Addrs so the report shows which addresses could not be reached. +func resolveStaticEdge(addrs []string, log *zerolog.Logger) []ResolvedTarget { + targets := make([]ResolvedTarget, 0, len(addrs)) + for _, addr := range addrs { + resolved := allregions.ResolveAddrs([]string{addr}, log) + if len(resolved) > 0 { + targets = append(targets, ResolvedTarget{ + Addrs: resolved, + DNSResult: newDNSCheckResult(addr, Pass, detailsResolvedSuccessfully, ""), + }) + } else { + targets = append(targets, ResolvedTarget{ + DNSResult: newDNSCheckResult(addr, Fail, detailsNoAddressesReturned, fmt.Sprintf(actionDNSFail, addr, addr)), + }) + } + } + return targets +} diff --git a/prechecks/probes_test.go b/prechecks/probes_test.go index 9c6b65d5..7a350f91 100644 --- a/prechecks/probes_test.go +++ b/prechecks/probes_test.go @@ -117,15 +117,14 @@ func TestProbeDNS_Success(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr, v6Addr}}, nil) - addrs, results := probeDNS(resolver, "") + targets := probeDNS(resolver, "") - require.NotNil(t, addrs) - require.Len(t, results, 1) - assert.Len(t, addrs, 1) - assert.Equal(t, ProbeTypeDNS, results[0].Type) - assert.Equal(t, testRegion1Global, results[0].Target) - assert.Equal(t, Pass, results[0].ProbeStatus) - assert.Equal(t, detailsResolvedSuccessfully, results[0].Details) + require.Len(t, targets, 1) + assert.NotEmpty(t, targets[0].Addrs) + assert.Equal(t, ProbeTypeDNS, targets[0].DNSResult.Type) + assert.Equal(t, testRegion1Global, targets[0].DNSResult.Target) + assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus) + assert.Equal(t, detailsResolvedSuccessfully, targets[0].DNSResult.Details) } func TestProbeDNS_MultipleRegions(t *testing.T) { @@ -139,17 +138,17 @@ func TestProbeDNS_MultipleRegions(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr1}, {v4Addr2}}, nil) - addrs, results := probeDNS(resolver, "") + targets := probeDNS(resolver, "") - require.NotNil(t, addrs) - require.Len(t, results, 2) - assert.Len(t, addrs, 2) + require.Len(t, targets, 2) - assert.Equal(t, testRegion1Global, results[0].Target) - assert.Equal(t, Pass, results[0].ProbeStatus) + assert.Equal(t, testRegion1Global, targets[0].DNSResult.Target) + assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus) + assert.NotEmpty(t, targets[0].Addrs) - assert.Equal(t, testRegion2Global, results[1].Target) - assert.Equal(t, Pass, results[1].ProbeStatus) + assert.Equal(t, testRegion2Global, targets[1].DNSResult.Target) + assert.Equal(t, Pass, targets[1].DNSResult.ProbeStatus) + assert.NotEmpty(t, targets[1].Addrs) } func TestProbeDNS_ResolverError(t *testing.T) { @@ -160,17 +159,16 @@ func TestProbeDNS_ResolverError(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("").Return(nil, errors.New("DNS lookup failed")) - addrs, results := probeDNS(resolver, "") + targets := probeDNS(resolver, "") - assert.Nil(t, addrs) - require.Len(t, results, 2) - - assert.Equal(t, Fail, results[0].ProbeStatus) - assert.Equal(t, "DNS lookup failed", results[0].Details) - assert.Contains(t, results[0].Action, testRegion1Global) - assert.Contains(t, results[1].Action, testRegion2Global) - - assert.Equal(t, Fail, results[1].ProbeStatus) + require.Len(t, targets, 2) + assert.Empty(t, targets[0].Addrs) + assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus) + assert.Equal(t, "DNS lookup failed", targets[0].DNSResult.Details) + assert.Contains(t, targets[0].DNSResult.Action, testRegion1Global) + assert.Empty(t, targets[1].Addrs) + assert.Equal(t, Fail, targets[1].DNSResult.ProbeStatus) + assert.Contains(t, targets[1].DNSResult.Action, testRegion2Global) } func TestProbeDNS_EmptyResults(t *testing.T) { @@ -181,12 +179,12 @@ func TestProbeDNS_EmptyResults(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{}, nil) - addrs, results := probeDNS(resolver, "") + targets := probeDNS(resolver, "") - assert.Nil(t, addrs) - require.Len(t, results, 2) - assert.Equal(t, Fail, results[0].ProbeStatus) - assert.Equal(t, "No addresses returned", results[0].Details) + require.Len(t, targets, 2) + assert.Empty(t, targets[0].Addrs) + assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus) + assert.Equal(t, detailsNoAddressesReturned, targets[0].DNSResult.Details) } func TestProbeDNS_EmptyGroup(t *testing.T) { @@ -197,12 +195,12 @@ func TestProbeDNS_EmptyGroup(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{}}, nil) - addrs, results := probeDNS(resolver, "") + targets := probeDNS(resolver, "") - require.NotNil(t, addrs) - require.Len(t, results, 1) - assert.Equal(t, Fail, results[0].ProbeStatus) - assert.Equal(t, "No addresses returned", results[0].Details) + require.Len(t, targets, 1) + assert.Empty(t, targets[0].Addrs) + assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus) + assert.Equal(t, detailsNoAddressesReturned, targets[0].DNSResult.Details) } func TestProbeDNS_RegionFlag(t *testing.T) { @@ -214,10 +212,10 @@ func TestProbeDNS_RegionFlag(t *testing.T) { resolver := mocks.NewMockDNSResolver(ctrl) resolver.EXPECT().Resolve("us").Return([][]*allregions.EdgeAddr{{v4Addr}}, nil) - _, results := probeDNS(resolver, "us") + targets := probeDNS(resolver, "us") - require.Len(t, results, 1) - assert.Equal(t, testRegion1US, results[0].Target) + require.Len(t, targets, 1) + assert.Equal(t, testRegion1US, targets[0].DNSResult.Target) } // probeQUIC tests. @@ -373,7 +371,7 @@ func TestProbeManagementAPI_DialError(t *testing.T) { func TestSkipResult(t *testing.T) { t.Parallel() - result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)") + result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)", detailsDNSPrerequisiteFailed) assert.Equal(t, ProbeTypeQUIC, result.Type) assert.Equal(t, "UDP Connectivity", result.Component) @@ -537,3 +535,61 @@ func TestProbeHTTP2_IPv6Address(t *testing.T) { assert.Equal(t, Pass, result.ProbeStatus) } + +// resolveStaticEdge tests. + +// TestResolveStaticEdge_SingleAddr verifies that a single resolvable --edge +// addr produces one group labeled with the original addr string. +func TestResolveStaticEdge_SingleAddr(t *testing.T) { + t.Parallel() + logger := zerolog.Nop() + targets := resolveStaticEdge([]string{"127.0.0.1:7844"}, &logger) + require.Len(t, targets, 1) + assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target) + assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus) + assert.NotEmpty(t, targets[0].Addrs) +} + +// TestResolveStaticEdge_MultipleAddrs verifies that multiple --edge addrs each +// produce their own ResolvedTarget, preserving per-addr structure and label order. +func TestResolveStaticEdge_MultipleAddrs(t *testing.T) { + t.Parallel() + logger := zerolog.Nop() + targets := resolveStaticEdge([]string{"127.0.0.1:7844", "127.0.0.2:7844"}, &logger) + require.Len(t, targets, 2) + assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target) + assert.Equal(t, "127.0.0.2:7844", targets[1].DNSResult.Target) +} + +// TestResolveStaticEdge_InvalidAddr verifies that an unresolvable addr is +// silently skipped and does not appear in the output. +func TestResolveStaticEdge_InvalidAddr(t *testing.T) { + t.Parallel() + logger := zerolog.Nop() + // "not-a-valid-addr" has no port — ResolveTCPAddr will fail. + targets := resolveStaticEdge([]string{"not-a-valid-addr"}, &logger) + require.Len(t, targets, 1) + assert.Equal(t, "not-a-valid-addr", targets[0].DNSResult.Target) + assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus) + assert.Equal(t, detailsNoAddressesReturned, targets[0].DNSResult.Details) + assert.Empty(t, targets[0].Addrs) +} + +// TestResolveStaticEdge_PartiallyValid verifies that a mix of valid and invalid +// addrs produces one ResolvedTarget per addr — valid ones with Addrs and a Skip +// DNSResult, invalid ones with nil Addrs and a Fail DNSResult. +func TestResolveStaticEdge_PartiallyValid(t *testing.T) { + t.Parallel() + logger := zerolog.Nop() + targets := resolveStaticEdge([]string{"127.0.0.1:7844", "not-a-valid-addr", "127.0.0.2:7844"}, &logger) + require.Len(t, targets, 3) + assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target) + assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus) + assert.NotEmpty(t, targets[0].Addrs) + assert.Equal(t, "not-a-valid-addr", targets[1].DNSResult.Target) + assert.Equal(t, Fail, targets[1].DNSResult.ProbeStatus) + assert.Empty(t, targets[1].Addrs) + assert.Equal(t, "127.0.0.2:7844", targets[2].DNSResult.Target) + assert.Equal(t, Pass, targets[2].DNSResult.ProbeStatus) + assert.NotEmpty(t, targets[2].Addrs) +} diff --git a/prechecks/types.go b/prechecks/types.go index 7b74a34d..54856b50 100644 --- a/prechecks/types.go +++ b/prechecks/types.go @@ -74,6 +74,19 @@ type CheckResult struct { Action string } +// ResolvedTarget bundles a resolved edge target's addresses with the DNS +// CheckResult that describes it. This keeps addr groups and their report rows +// together as a single unit, avoiding parallel-slice synchronization. +type ResolvedTarget struct { + // Addrs holds the resolved edge addresses for this target. May be empty + // when DNS resolution succeeded structurally but returned no IPs. + Addrs []*allregions.EdgeAddr + + // DNSResult is the CheckResult representing DNS resolution for this target. + // Its Target field is the human-readable label used across all probe rows. + DNSResult CheckResult +} + // Report aggregates all CheckResults produced by a single Run() invocation. // Pre-checks run in parallel with tunnel initialization and are purely // diagnostic: the Report is displayed to the user but never gates startup. @@ -107,4 +120,10 @@ type Config struct { // checks. It mirrors the --edge-ip-version CLI flag so that the pre-check // exercises the same code paths the tunnel itself will use. IPVersion allregions.ConfigIPVersion + + // EdgeAddrs, when non-empty, contains the --edge flag values (explicit + // edge addresses). When set, DNS probing is skipped entirely — there are + // no SRV records to validate — and transport probes target each addr + // individually, labeled with the original addr string. + EdgeAddrs []string }