Files
cloudflared/prechecks/checker.go
T
Miguel da Costa Martins Marcelino 3315fa6e0f TUN-10630: Fix precheck protocol override
As it stands, cloudflared prechecks do not take the `protocol` flag into consideration and instead fall back to the default protocol, which is QUIC. Prechecks should report the protocol cloudflared will use, not the default protocol.
2026-06-18 10:56:53 +00:00

384 lines
13 KiB
Go

package prechecks
import (
"context"
"fmt"
"slices"
"time"
"github.com/cloudflare/backoff"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
)
const (
	// maxRetries is the number of extra attempts a failing probe receives
	// after its first try.
	maxRetries = 2

	// retryBaseDelay and maxRetryDelay configure the exponential backoff
	// applied between retry attempts; the delay never exceeds maxRetryDelay.
	retryBaseDelay = 1 * time.Second
	maxRetryDelay  = 16 * time.Second

	// defaultTimeout bounds the whole pre-check suite when Config.Timeout is
	// zero or negative.
	defaultTimeout = 10 * time.Second
)
// RunDialers bundles the network dependencies Run relies on so they can be
// injected: production code wires in real implementations, tests swap in
// mocks.
type RunDialers struct {
	DNSResolver      DNSResolver      // consumed by the SRV-based DNS probe
	TCPDialer        TCPDialer        // consumed by the HTTP/2 probe
	QUICDialer       QUICDialer       // consumed by the QUIC probe
	ManagementDialer ManagementDialer // consumed by the Management API probe
}
// TransportResults gathers the outcome of every transport probe. QUIC and
// HTTP2 carry one entry per resolved target group, ordered the same way as
// the target labels; the Management API is probed against a single target.
type TransportResults struct {
	QUIC          []CheckResult // one entry per target group
	HTTP2         []CheckResult // one entry per target group
	ManagementAPI CheckResult   // single target, not grouped
}
// Collect flattens the results into the fixed order used for reporting:
// every QUIC row, then every HTTP/2 row, and finally the Management API row.
func (tr TransportResults) Collect() []CheckResult {
	total := len(tr.QUIC) + len(tr.HTTP2) + 1
	out := make([]CheckResult, 0, total)
	for _, rows := range [][]CheckResult{tr.QUIC, tr.HTTP2} {
		out = append(out, rows...)
	}
	return append(out, tr.ManagementAPI)
}
// Run executes the following connectivity pre-checks:
//
//  1. Edge address resolution — either DNS-based SRV discovery (normal path)
//     or direct resolution of --edge addresses (static path). The static path
//     skips DNS probe rows entirely since there are no SRV records to validate.
//  2. QUIC, HTTP/2, and Management API probes run concurrently against the
//     resolved addresses. When no target resolved to any address, the QUIC
//     and HTTP/2 rows are emitted as skips (one per target) instead.
//
// Each failed probe is retried up to maxRetries times with exponential backoff.
// The suite is bounded by cfg.Timeout (defaultTimeout if zero or negative).
// The report's SuggestedProtocol honours cfg.ProtocolOverride via
// suggestProtocol.
func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report {
	runID := uuid.New()
	if cfg.Timeout <= 0 {
		cfg.Timeout = defaultTimeout
	}
	ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
	defer cancel()

	// 1) Resolve edge addresses. Each ResolvedTarget bundles its addr group
	// with the DNS CheckResult that labels it, keeping the two in sync.
	var resolvedTargets []ResolvedTarget
	if len(cfg.EdgeAddrs) > 0 {
		// Static path: explicit --edge addresses, one ResolvedTarget per addr.
		resolvedTargets = resolveStaticEdge(cfg.EdgeAddrs, log)
	} else {
		// Normal path: SRV-based discovery; DNS rows carry Pass or Fail status.
		resolvedTargets = runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
	}

	// Extract parallel slices for the transport probe layer.
	// nolint:prealloc // False positive. The linter is confused by the append used when producing Report.Results
	dnsResults := make([]CheckResult, len(resolvedTargets))
	perGroupAddrs := make([][]*allregions.EdgeAddr, len(resolvedTargets))
	targetLabels := make([]string, len(resolvedTargets))
	for i, rt := range resolvedTargets {
		dnsResults[i] = rt.DNSResult
		perGroupAddrs[i] = rt.Addrs
		targetLabels[i] = rt.DNSResult.Target
	}

	// dnsOK is true when at least one target has addresses to probe.
	dnsOK := slices.ContainsFunc(resolvedTargets, func(r ResolvedTarget) bool {
		return len(r.Addrs) > 0
	})

	// 2) Run transport probes concurrently. Each probe type gets its own
	// channel with a one-slot buffer — one send, one receive, no routing —
	// so a sender goroutine can always finish its send and exit.
	var results TransportResults
	mgmtCh := make(chan CheckResult, 1)
	go func() {
		mgmtCh <- probeManagementAPIWithRetry(ctx, runDialers.ManagementDialer)
	}()

	if !dnsOK {
		// No addresses available: emit one skip row per target so the table
		// stays consistent with the DNS rows above.
		results.QUIC = skipResultsForTargets(dnsResults, ProbeTypeQUIC, componentUDPConnectivity)
		results.HTTP2 = skipResultsForTargets(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity)
	} else {
		// Build TLS configs once per protocol. They are only consumed when
		// there is something to dial, so they are built on this path only.
		quicTLSConfig, quicTLSErr := probeTLSConfig(caCert, connection.QUIC)
		http2TLSConfig, http2TLSErr := probeTLSConfig(caCert, connection.HTTP2)

		filteredAddrs := addrsByGroup(perGroupAddrs, cfg.IPVersion)
		quicCh := make(chan []CheckResult, 1)
		http2Ch := make(chan []CheckResult, 1)

		go func() {
			if quicTLSErr != nil {
				log.Warn().Err(quicTLSErr).Msg("Failed to build QUIC probe TLS config")
				quicCh <- tlsConfigErrResults(ProbeTypeQUIC, componentUDPConnectivity,
					targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked)
				return
			}
			quicCh <- probeAllTargets(ctx, ProbeTypeQUIC, componentUDPConnectivity,
				filteredAddrs, targetLabels,
				func(addr *allregions.EdgeAddr) CheckResult {
					return probeQUIC(ctx, quicTLSConfig, runDialers.QUICDialer, addr, log)
				})
		}()

		go func() {
			if http2TLSErr != nil {
				log.Warn().Err(http2TLSErr).Msg("Failed to build HTTP/2 probe TLS config")
				http2Ch <- tlsConfigErrResults(ProbeTypeHTTP2, componentTCPConnectivity,
					targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked)
				return
			}
			http2Ch <- probeAllTargets(ctx, ProbeTypeHTTP2, componentTCPConnectivity,
				filteredAddrs, targetLabels,
				func(addr *allregions.EdgeAddr) CheckResult {
					return probeHTTP2(ctx, http2TLSConfig, runDialers.TCPDialer, addr)
				})
		}()

		results.QUIC = <-quicCh
		results.HTTP2 = <-http2Ch
	}
	results.ManagementAPI = <-mgmtCh

	return Report{
		RunID:             runID,
		Results:           append(dnsResults, results.Collect()...),
		SuggestedProtocol: suggestProtocol(results.QUIC, results.HTTP2, cfg.ProtocolOverride),
	}
}
// tlsConfigErrResults builds one Fail row per target label. Run emits these
// when the probe TLS config cannot be constructed, so no dial is attempted;
// every row carries the same details and suggested action.
func tlsConfigErrResults(probeType ProbeType, component string, targets []string, details, action string) []CheckResult {
	out := make([]CheckResult, 0, len(targets))
	for _, label := range targets {
		out = append(out, CheckResult{
			Type:        probeType,
			Component:   component,
			Target:      label,
			ProbeStatus: Fail,
			Details:     details,
			Action:      action,
		})
	}
	return out
}
// probeAllTargets walks the target groups one after another (not in
// parallel) and returns one CheckResult per group, in group order.
// targets[i] is the label for perGroupAddrs[i]; probeTarget decides which of
// a group's address results is reported.
func probeAllTargets(
	ctx context.Context,
	probeType ProbeType,
	component string,
	perGroupAddrs [][]*allregions.EdgeAddr,
	targets []string,
	probeFn func(*allregions.EdgeAddr) CheckResult,
) []CheckResult {
	out := make([]CheckResult, 0, len(perGroupAddrs))
	for i := range perGroupAddrs {
		out = append(out, probeTarget(ctx, probeType, component, targets[i], perGroupAddrs[i], probeFn))
	}
	return out
}
// probeTarget probes the addresses of a single target group (typically one V4
// and/or one V6 address) and returns the most useful result, relabelled with
// target.
//
// Any passing address means the target is reachable, so probing stops at the
// first Pass: a Pass cannot be improved upon, and retrying the remaining
// (possibly unreachable) addresses would only consume the shared suite
// deadline that later target groups still need. If no address passes, the
// first address's result is reported. An empty addrs yields a Skip row.
func probeTarget(
	ctx context.Context,
	probeType ProbeType,
	component string,
	target string,
	addrs []*allregions.EdgeAddr,
	probeFn func(*allregions.EdgeAddr) CheckResult,
) CheckResult {
	if len(addrs) == 0 {
		return CheckResult{
			Type:        probeType,
			Component:   component,
			Target:      target,
			ProbeStatus: Skip,
			Details:     "No suitable address found for configured IP version",
		}
	}
	best := probeWithRetry(ctx, addrs[0], probeFn)
	for _, addr := range addrs[1:] {
		if best.ProbeStatus == Pass {
			// Already reachable; further probes cannot change the outcome.
			break
		}
		if r := probeWithRetry(ctx, addr, probeFn); r.ProbeStatus == Pass {
			best = r
		}
	}
	best.Target = target
	return best
}
// probeManagementAPIWithRetry checks Cloudflare API reachability, handing the
// probe to withRetry so it is repeated until it passes or the retry budget
// (maxRetries) is spent. The result of the last attempt made is returned.
func probeManagementAPIWithRetry(ctx context.Context, dialer ManagementDialer) CheckResult {
	var last CheckResult
	attempt := func() bool {
		last = probeManagementAPI(ctx, dialer)
		return last.ProbeStatus == Pass
	}
	withRetry(ctx, maxRetries, attempt)
	return last
}
// probeWithRetry runs probeFn against addr via withRetry, allowing at most
// maxRetries extra attempts with exponential backoff and stopping at the
// first Pass. It returns whatever the final attempt produced.
func probeWithRetry(ctx context.Context, addr *allregions.EdgeAddr, probeFn func(*allregions.EdgeAddr) CheckResult) CheckResult {
	var latest CheckResult
	withRetry(ctx, maxRetries, func() bool {
		latest = probeFn(addr)
		passed := latest.ProbeStatus == Pass
		return passed
	})
	return latest
}
// addrsByGroup narrows every resolved target group to the addresses worth
// probing, keeping one inner slice per group in the original order. Each
// inner slice holds the V4 address (first) and/or the V6 address chosen by
// addrsByFamily for ipVersion; a group where neither was chosen maps to a nil
// inner slice.
func addrsByGroup(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr {
	perGroup := make([][]*allregions.EdgeAddr, len(addrGroups))
	for i, group := range addrGroups {
		v4, v6 := addrsByFamily(group, ipVersion)
		for _, candidate := range []*allregions.EdgeAddr{v4, v6} {
			if candidate != nil {
				perGroup[i] = append(perGroup[i], candidate)
			}
		}
	}
	return perGroup
}
// skipResultsForTargets produces one row (via skipResult) for every entry in
// targets, reusing that entry's Target label so the transport row lines up
// with its DNS row. Every row carries detailsDNSPrerequisiteFailed.
func skipResultsForTargets(targets []CheckResult, probeType ProbeType, component string) []CheckResult {
	skipped := make([]CheckResult, 0, len(targets))
	for i := range targets {
		skipped = append(skipped, skipResult(probeType, component, targets[i].Target, detailsDNSPrerequisiteFailed))
	}
	return skipped
}
// worstStatus reports the most severe status present in results, ranking
// Fail above Pass above Skip (see severity); an empty slice yields Skip.
// Callers use it to treat a transport as failed as a whole as soon as any
// single region fails.
func worstStatus(results []CheckResult) Status {
	worst := Skip
	for i := range results {
		if s := results[i].ProbeStatus; severity(s) > severity(worst) {
			worst = s
		}
	}
	return worst
}
// severity ranks a Status so that worse outcomes compare higher:
// Fail (2) > Pass (1) > Skip and any unrecognised status (0).
func severity(s Status) int {
	switch s {
	case Fail:
		return 2
	case Pass:
		return 1
	}
	return 0
}
// parseProtocolOverride maps the raw --protocol flag value to a freshly
// allocated *connection.Protocol. "h2mux" is folded into HTTP/2 because both
// use the same transport. Anything else — empty, "auto", or an unrecognised
// value — yields nil, meaning the probe heuristic should decide.
func parseProtocolOverride(flag string) *connection.Protocol {
	var selected connection.Protocol
	switch flag {
	case connection.QUIC.String():
		selected = connection.QUIC
	case connection.HTTP2.String(), "h2mux":
		selected = connection.HTTP2
	default:
		return nil
	}
	return &selected
}
// suggestProtocol determines the protocol to report in the pre-check summary.
//
// Explicit override (overrideFlag parses to QUIC or HTTP/2 through
// parseProtocolOverride): the overridden protocol is reported as long as its
// transport produced at least one result row and none of those rows failed.
// Skipped rows do not count as failures. If any row of the overridden
// transport failed (or it has no rows), the override is ignored and the
// auto-selection rules below apply, so the summary can still report a usable
// fallback or nil.
//
// Auto-selection (empty, "auto", or unrecognised flag): QUIC is preferred,
// then HTTP/2, then nil. A transport qualifies only when none of its rows
// failed and at least one passed. Any region failing means the whole
// transport is treated as failed (worst status wins).
func suggestProtocol(quicResults, http2Results []CheckResult, overrideFlag string) *connection.Protocol {
	if override := parseProtocolOverride(overrideFlag); override != nil {
		var overridden []CheckResult
		switch *override {
		case connection.QUIC:
			overridden = quicResults
		case connection.HTTP2:
			overridden = http2Results
		}
		if len(overridden) > 0 && worstStatus(overridden) != Fail {
			// parseProtocolOverride allocates a new value on every call, so the
			// pointer can be handed to the caller directly.
			return override
		}
	}

	if len(quicResults) > 0 && worstStatus(quicResults) == Pass {
		quic := connection.QUIC
		return &quic
	}
	if len(http2Results) > 0 && worstStatus(http2Results) == Pass {
		http2 := connection.HTTP2
		return &http2
	}
	return nil
}
// withRetry calls fn until it returns true, making at most 1+retries calls.
// fn is always called at least once; a negative retries is treated as zero
// (previously a negative value skipped fn entirely, leaving callers with a
// zero-valued result). Between attempts it sleeps with exponential backoff
// configured from retryBaseDelay and capped at maxRetryDelay, and returns
// early — without calling fn again — once ctx is done.
func withRetry(ctx context.Context, retries int, fn func() bool) {
	retries = max(retries, 0)
	b := backoff.NewWithoutJitter(maxRetryDelay, retryBaseDelay)
	for attempt := 0; ; attempt++ {
		if fn() || attempt == retries {
			return
		}
		timer := time.NewTimer(b.Duration())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}