diff --git a/go.mod b/go.mod index 886269fa..4501e503 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bytedance/sonic v1.12.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/ebitengine/purego v0.10.0 // indirect diff --git a/go.sum b/go.sum index 4d620f86..aa74875e 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chungthuang/quic-go v0.45.1-0.20250428085412-43229ad201fd h1:VdYI5zFQ2h1/qzoC6rhyPx479bkF8i177Qpg4Q2n1vk= github.com/chungthuang/quic-go v0.45.1-0.20250428085412-43229ad201fd/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= diff --git a/prechecks/checker.go b/prechecks/checker.go new file mode 100644 index 00000000..24ac3915 --- /dev/null +++ b/prechecks/checker.go @@ -0,0 +1,308 @@ +package prechecks + +import ( + "context" + "slices" + "time" + + "github.com/cloudflare/backoff" + "github.com/google/uuid" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" +) + +const ( + defaultTimeout = 10 * time.Second + 
maxRetries     = 2
+	retryBaseDelay = 1 * time.Second
+	maxRetryDelay  = 16 * time.Second
+)
+
+// RunDialers holds the injectable dependencies for Run(). Production callers build
+// this with real implementations; tests supply mocks.
+type RunDialers struct {
+	DNSResolver      DNSResolver      // used by the DNS probe
+	TCPDialer        TCPDialer        // used by the HTTP/2 probe
+	QUICDialer       QUICDialer       // used by the QUIC probe
+	ManagementDialer ManagementDialer // used by the Management API probe
+}
+
+// TransportResults holds the per-region results for each transport probe type.
+// Each slice has one entry per DNS-resolved region, in the same order as dnsResults.
+type TransportResults struct {
+	QUIC          []CheckResult // one per region
+	HTTP2         []CheckResult // one per region
+	ManagementAPI CheckResult   // single target, no regions
+}
+
+// Collect returns all results as a slice in a consistent order for reporting:
+// all QUIC rows first (one per region), then all HTTP2 rows, then Management API.
+func (tr TransportResults) Collect() []CheckResult {
+	results := make([]CheckResult, 0, len(tr.QUIC)+len(tr.HTTP2)+1)
+	results = append(results, tr.QUIC...)
+	results = append(results, tr.HTTP2...)
+	results = append(results, tr.ManagementAPI)
+	return results
+}
+
+// Run executes the following connectivity pre-checks:
+//
+//  1. DNS resolution (sequential – transport probes depend on its output).
+//  2. QUIC, HTTP/2, and Management API probes run concurrently.
+//
+// When any DNS row is not Pass, the QUIC and HTTP/2 probes are not dialed and
+// one Skip row per DNS row is emitted instead; the Management API probe still
+// runs. Each failed probe is retried up to maxRetries times with exponential
+// backoff. The suite is bounded by cfg.Timeout (defaultTimeout if zero or
+// negative).
+//
+// The returned Report lists the DNS rows first, followed by the transport rows
+// in TransportResults.Collect order.
+func Run(ctx context.Context, cfg Config, log *zerolog.Logger, runDialers RunDialers) Report {
+	runID := uuid.New()
+
+	if cfg.Timeout <= 0 {
+		cfg.Timeout = defaultTimeout
+	}
+	ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
+	defer cancel()
+
+	// 1) DNS – must complete before transport probes know which addresses to dial.
+	addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
+
+	dnsOK := !slices.ContainsFunc(dnsResults, func(r CheckResult) bool {
+		return r.ProbeStatus != Pass
+	})
+
+	// 2) Run probes concurrently. Each probe type gets its own buffered channel —
+	// one send, one receive, no routing or name-parsing required.
+	var results TransportResults
+
+	// Buffered like the transport channels so the sending goroutine can always
+	// finish without waiting for the receive at the end of Run.
+	mgmtCh := make(chan CheckResult, 1)
+	go func() {
+		mgmtCh <- probeManagementAPIWithRetry(ctx, runDialers.ManagementDialer)
+	}()
+
+	if !dnsOK {
+		// DNS failed: emit one skip row per region so the table stays consistent.
+		results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, "UDP Connectivity")
+		results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, "TCP Connectivity")
+	} else {
+		perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion)
+		regionTargets := dnsTargets(dnsResults)
+
+		quicCh := make(chan []CheckResult, 1)
+		http2Ch := make(chan []CheckResult, 1)
+
+		go func() {
+			quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, "UDP Connectivity",
+				perRegionAddrs, regionTargets,
+				func(addr *allregions.EdgeAddr) CheckResult {
+					return probeQUIC(ctx, runDialers.QUICDialer, addr, log)
+				})
+		}()
+
+		go func() {
+			http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, "TCP Connectivity",
+				perRegionAddrs, regionTargets,
+				func(addr *allregions.EdgeAddr) CheckResult {
+					return probeHTTP2(ctx, runDialers.TCPDialer, addr)
+				})
+		}()
+
+		results.QUIC = <-quicCh
+		results.HTTP2 = <-http2Ch
+	}
+
+	results.ManagementAPI = <-mgmtCh
+
+	return Report{
+		RunID:             runID,
+		Results:           append(dnsResults, results.Collect()...),
+		SuggestedProtocol: suggestProtocol(results.QUIC, results.HTTP2),
+	}
+}
+
+// runDNSProbe runs probeDNS with retry and returns the address groups and DNS
+// rows of the last attempt. An attempt counts as successful when it produced at
+// least one row and none of the rows is Fail.
+func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) {
+	var addrGroups [][]*allregions.EdgeAddr
+	var dnsResults []CheckResult
+	withRetry(ctx, maxRetries, func() bool {
+		addrGroups, dnsResults = probeDNS(resolver, region)
+		for _, r := range dnsResults {
+			if r.ProbeStatus == Fail {
+				return false
+			}
+		}
+		return len(dnsResults) > 0
+	})
+	return addrGroups, dnsResults
+}
+
+// probeAllRegions probes each region sequentially and returns one CheckResult
+// per region. Within each region, all available addresses (V4 and/or V6) are
+// tried and the best result is kept.
+//
+// regionTargets is expected to parallel perRegionAddrs. If it is shorter, the
+// regions without a matching entry are reported with an empty Target instead
+// of panicking on an out-of-range index.
+func probeAllRegions(
+	ctx context.Context,
+	probeType ProbeType,
+	component string,
+	perRegionAddrs [][]*allregions.EdgeAddr,
+	regionTargets []string,
+	probeFn func(*allregions.EdgeAddr) CheckResult,
+) []CheckResult {
+	results := make([]CheckResult, len(perRegionAddrs))
+	for i, addrs := range perRegionAddrs {
+		var target string
+		if i < len(regionTargets) {
+			target = regionTargets[i]
+		}
+		results[i] = probeRegion(ctx, probeType, component, target, addrs, probeFn)
+	}
+	return results
+}
+
+// probeRegion probes all addresses for a single region (typically one V4 and/or
+// one V6) and returns the best result. Any address passing means the region is
+// reachable, so Pass beats Fail within a region. Every address is probed even
+// after one has passed; the first address's result is kept unless a later
+// address passes. A region with no addresses yields a Skip row. The returned
+// row's Target is always regionTarget.
+func probeRegion(
+	ctx context.Context,
+	probeType ProbeType,
+	component string,
+	regionTarget string,
+	addrs []*allregions.EdgeAddr,
+	probeFn func(*allregions.EdgeAddr) CheckResult,
+) CheckResult {
+	if len(addrs) == 0 {
+		return CheckResult{
+			Type:        probeType,
+			Component:   component,
+			Target:      regionTarget,
+			ProbeStatus: Skip,
+			Details:     "No suitable address found for configured IP version",
+		}
+	}
+
+	best := probeWithRetry(ctx, addrs[0], probeFn)
+	for _, addr := range addrs[1:] {
+		if r := probeWithRetry(ctx, addr, probeFn); r.ProbeStatus == Pass {
+			best = r
+		}
+	}
+	// Report the region hostname rather than whatever target the probe set.
+	best.Target = regionTarget
+	return best
+}
+
+// probeManagementAPIWithRetry runs the Cloudflare API reachability probe with retry.
+func probeManagementAPIWithRetry(ctx context.Context, dialer ManagementDialer) CheckResult {
+	// last always holds the outcome of the most recent attempt.
+	var last CheckResult
+	attempt := func() bool {
+		last = probeManagementAPI(ctx, dialer)
+		return last.ProbeStatus == Pass
+	}
+	withRetry(ctx, maxRetries, attempt)
+	return last
+}
+
+// probeWithRetry runs probeFn against addr through withRetry with maxRetries
+// retries, stopping at the first Pass, and returns the result of the last
+// attempt that ran.
+func probeWithRetry(ctx context.Context, addr *allregions.EdgeAddr, probeFn func(*allregions.EdgeAddr) CheckResult) CheckResult {
+	var last CheckResult
+	attempt := func() bool {
+		last = probeFn(addr)
+		return last.ProbeStatus == Pass
+	}
+	withRetry(ctx, maxRetries, attempt)
+	return last
+}
+
+// addrsByRegion maps every DNS address group to the addresses that should be
+// probed for that region, keeping the groups in their original order. For each
+// group it keeps the non-nil V4 and V6 picks of addrsByFamily, V4 first, so an
+// inner slice holds at most two addresses and is nil when nothing was picked.
+func addrsByRegion(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr {
+	perRegion := make([][]*allregions.EdgeAddr, len(addrGroups))
+	for i, group := range addrGroups {
+		v4, v6 := addrsByFamily(group, ipVersion)
+		var picked []*allregions.EdgeAddr
+		for _, candidate := range [...]*allregions.EdgeAddr{v4, v6} {
+			if candidate != nil {
+				picked = append(picked, candidate)
+			}
+		}
+		perRegion[i] = picked
+	}
+	return perRegion
+}
+
+// dnsTargets collects the Target of every DNS CheckResult, in order, so that
+// the transport probe rows can be labelled with the same region hostnames.
+func dnsTargets(dnsResults []CheckResult) []string {
+	targets := make([]string, 0, len(dnsResults))
+	for _, row := range dnsResults {
+		targets = append(targets, row.Target)
+	}
+	return targets
+}
+
+// skipResultsForRegions returns one skip CheckResult per DNS region, using each
+// region's hostname as the Target so the output table row aligns with its DNS row.
+func skipResultsForRegions(dnsResults []CheckResult, probeType ProbeType, component string) []CheckResult {
+	skipped := make([]CheckResult, 0, len(dnsResults))
+	for _, row := range dnsResults {
+		skipped = append(skipped, skipResult(probeType, component, row.Target))
+	}
+	return skipped
+}
+
+// worstStatus reports the most severe Status found in results, ranked by
+// severity (Fail > Pass > Skip). An empty slice yields Skip. It is used to
+// decide whether a transport as a whole failed: one failing region is enough.
+func worstStatus(results []CheckResult) Status {
+	worst, worstRank := Skip, severity(Skip)
+	for _, r := range results {
+		if rank := severity(r.ProbeStatus); rank > worstRank {
+			worst, worstRank = r.ProbeStatus, rank
+		}
+	}
+	return worst
+}
+
+// severity ranks a Status so that worse outcomes compare higher: Fail is 2,
+// Pass is 1, and Skip or any other value is 0.
+func severity(s Status) int {
+	switch s {
+	case Fail:
+		return 2
+	case Pass:
+		return 1
+	}
+	return 0
+}
+
+// suggestProtocol picks the protocol to recommend from the per-region results.
+// A transport is usable when it has at least one row, no row is Fail, and at
+// least one row is Pass (Skip rows do not count against it). QUIC is preferred
+// when usable, then HTTP/2; nil is returned when neither is usable.
+func suggestProtocol(quicResults, http2Results []CheckResult) *connection.Protocol {
+	usable := func(results []CheckResult) bool {
+		return len(results) > 0 && worstStatus(results) == Pass
+	}
+	switch {
+	case usable(quicResults):
+		suggested := connection.QUIC
+		return &suggested
+	case usable(http2Results):
+		suggested := connection.HTTP2
+		return &suggested
+	default:
+		return nil
+	}
+}
+
+// withRetry calls fn up to 1+maxAttempts times, stopping as soon as fn returns
+// true. Between attempts it sleeps with exponential backoff bounded by
+// maxRetryDelay, and stops early if ctx is done.
+func withRetry(ctx context.Context, maxAttempts int, fn func() bool) {
+	delays := backoff.NewWithoutJitter(maxRetryDelay, retryBaseDelay)
+	for attempt := 0; attempt <= maxAttempts; attempt++ {
+		if fn() {
+			return
+		}
+		// No point in sleeping after the final attempt.
+		if attempt == maxAttempts {
+			return
+		}
+		wait := time.NewTimer(delays.Duration())
+		select {
+		case <-wait.C:
+			// Backoff elapsed; go around for the next attempt.
+		case <-ctx.Done():
+			wait.Stop()
+			return
+		}
+	}
+}
diff --git a/prechecks/checker_test.go b/prechecks/checker_test.go
new file mode 100644
index 00000000..f0828575
--- /dev/null
+++ b/prechecks/checker_test.go
@@ -0,0 +1,471 @@
+package prechecks
+
+import (
+	"errors"
+	"math"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/quic-go/quic-go"
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	"github.com/cloudflare/cloudflared/connection"
+	"github.com/cloudflare/cloudflared/edgediscovery/allregions"
+	"github.com/cloudflare/cloudflared/mocks"
+)
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+// twoRegionAddrs builds the DNS fixture for tests that only exercise the V4
+// path: two regions, each holding a single IPv4 edge address on port 7844.
+func twoRegionAddrs() [][]*allregions.EdgeAddr {
+	v4Addr := func(literal string) *allregions.EdgeAddr {
+		ip := net.ParseIP(literal)
+		return &allregions.EdgeAddr{
+			TCP:       &net.TCPAddr{IP: ip, Port: 7844},
+			UDP:       &net.UDPAddr{IP: ip, Port: 7844},
+			IPVersion: allregions.V4,
+		}
+	}
+	return [][]*allregions.EdgeAddr{
+		{v4Addr("1.2.3.4")},
+		{v4Addr("5.6.7.8")},
+	}
+}
+
+// twoRegionAddrsBothFamilies returns a two-group [][]*EdgeAddr with one IPv4
+// and one IPv6 address per region, used by per-family probe tests.
+func twoRegionAddrsBothFamilies() [][]*allregions.EdgeAddr {
+	edgeAddr := func(literal string, family allregions.EdgeIPVersion) *allregions.EdgeAddr {
+		ip := net.ParseIP(literal)
+		return &allregions.EdgeAddr{
+			TCP:       &net.TCPAddr{IP: ip, Port: 7844},
+			UDP:       &net.UDPAddr{IP: ip, Port: 7844},
+			IPVersion: family,
+		}
+	}
+	region1 := []*allregions.EdgeAddr{
+		edgeAddr("1.2.3.4", allregions.V4),
+		edgeAddr("2001:db8::1", allregions.V6),
+	}
+	region2 := []*allregions.EdgeAddr{
+		edgeAddr("5.6.7.8", allregions.V4),
+		edgeAddr("2001:db8::2", allregions.V6),
+	}
+	return [][]*allregions.EdgeAddr{region1, region2}
+}
+
+// nopConn wraps a (nil) net.Conn and overrides Close() to do nothing. The TCP
+// and management dial mocks return it as their successful connection.
+type nopConn struct{ net.Conn }
+
+func (nopConn) Close() error { return nil }
+
+// fakeQUICConn stands in for quic.Connection in tests. It overrides only
+// CloseWithError; every other method comes from the embedded nil interface and
+// must not be called.
+type fakeQUICConn struct {
+	quic.Connection
+}
+
+func (*fakeQUICConn) CloseWithError(_ quic.ApplicationErrorCode, _ string) error { return nil }
+
+// requireStatuses checks that report.Results has exactly len(expected) rows
+// (aborting the test otherwise) and that the ProbeStatus of each row matches
+// expected, position by position.
+func requireStatuses(t *testing.T, report Report, expected ...Status) {
+	t.Helper()
+	require.Len(t, report.Results, len(expected))
+	for i, want := range expected {
+		row := report.Results[i]
+		got := row.ProbeStatus
+		assert.Equalf(t, want, got,
+			"result[%d] (%s/%s): got %s, want %s",
+			i, row.Component, row.Target, got, want)
+	}
+}
+
+// nopLogger returns a pointer to a zerolog logger that discards everything.
+func nopLogger() *zerolog.Logger {
+	logger := zerolog.Nop()
+	return &logger
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+// TestRun_AllPass verifies that when all probes succeed the report contains
+// 7 rows: 2 DNS + 2 QUIC (one per region) + 2 HTTP/2 (one per region) + 1 API.
+func TestRun_AllPass(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + // twoRegionAddrs has 2 regions × 1 V4 address each = 2 dials per transport. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results. + requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) + assert.NotEqual(t, uuid.Nil, report.RunID, "RunID must be set") + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) + assert.False(t, report.hasHardFail()) + assert.False(t, report.hasWarn()) +} + +// TestRun_QUICBlocked verifies that when QUIC is blocked on all regions, +// the report is degraded (warn) and HTTP/2 is the suggested protocol. +func TestRun_QUICBlocked(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(nopConn{}, nil).AnyTimes() + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection refused")).AnyTimes() + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Pass + 1 API Pass. + requireStatuses(t, report, Pass, Pass, Fail, Fail, Pass, Pass, Pass) + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.HTTP2, *report.SuggestedProtocol) + assert.False(t, report.hasHardFail()) + assert.True(t, report.hasWarn()) +} + +// TestRun_HTTP2Blocked verifies that when HTTP/2 is blocked on all regions, +// the report is degraded (warn) and QUIC is the suggested protocol. +func TestRun_HTTP2Blocked(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection refused")).AnyTimes() + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).AnyTimes() + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Fail + 1 API Pass. + requireStatuses(t, report, Pass, Pass, Pass, Pass, Fail, Fail, Pass) + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) + assert.False(t, report.hasHardFail()) + assert.True(t, report.hasWarn()) +} + +// TestRun_BothTransportsBlocked verifies that when both QUIC and HTTP/2 are +// blocked on all regions it is a hard fail with no suggested protocol. +func TestRun_BothTransportsBlocked(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("blocked")).AnyTimes() + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("blocked")).AnyTimes() + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS Pass + 2 QUIC Fail + 2 HTTP2 Fail + 1 API Pass. 
+ requireStatuses(t, report, Pass, Pass, Fail, Fail, Fail, Fail, Pass) + assert.Nil(t, report.SuggestedProtocol) + assert.True(t, report.hasHardFail()) +} + +// TestRun_PartialRegionQUICFail verifies "worst wins" semantics: when QUIC +// passes for region1 but fails for region2, QUIC is treated as failed and +// HTTP/2 becomes the suggested protocol. +func TestRun_PartialRegionQUICFail(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + // Two regions: 1.2.3.4 (region1) and 5.6.7.8 (region2). + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + + // TCP/HTTP2: both regions pass. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).AnyTimes() + + // QUIC: region1 (1.2.3.4) passes, region2 (5.6.7.8) fails. + region1Addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7844} + region2Addr := &net.UDPAddr{IP: net.ParseIP("5.6.7.8"), Port: 7844} + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), region1Addr.AddrPort(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).AnyTimes() + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), region2Addr.AddrPort(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection refused")).AnyTimes() + + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS Pass + QUIC-region1 Pass + QUIC-region2 Fail + 2 HTTP2 Pass + 1 API Pass. 
+ requireStatuses(t, report, Pass, Pass, Pass, Fail, Pass, Pass, Pass) + + // Worst wins: region2 QUIC failed, so QUIC is treated as failed overall. + // HTTP/2 passes on all regions → HTTP/2 is the suggested protocol. + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.HTTP2, *report.SuggestedProtocol) + assert.False(t, report.hasHardFail()) + assert.True(t, report.hasWarn()) +} + +// TestRun_DNSFail_SkipsTransports verifies that when DNS fails, transport rows +// are emitted as Skip (one per DNS region) and no transport dials are made. +func TestRun_DNSFail_SkipsTransports(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()). + Return(nil, errors.New("no such host")).AnyTimes() + // Transport dialers must NOT be called when DNS fails. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // DNS failure emits 2 Fail rows (one per default region). + // Transport rows: one skip per DNS region for QUIC and HTTP/2 = 2 QUIC skips + 2 HTTP2 skips. + // 2 DNS Fail + 2 QUIC Skip + 2 HTTP2 Skip + 1 API Pass = 7 results. 
+ require.Len(t, report.Results, 7) + assert.Equal(t, Fail, report.Results[0].ProbeStatus, "DNS region1") + assert.Equal(t, Fail, report.Results[1].ProbeStatus, "DNS region2") + assert.Equal(t, Skip, report.Results[2].ProbeStatus, "QUIC region1 must be skipped") + assert.Equal(t, Skip, report.Results[3].ProbeStatus, "QUIC region2 must be skipped") + assert.Equal(t, Skip, report.Results[4].ProbeStatus, "HTTP/2 region1 must be skipped") + assert.Equal(t, Skip, report.Results[5].ProbeStatus, "HTTP/2 region2 must be skipped") + assert.Equal(t, Pass, report.Results[6].ProbeStatus, "API still runs") + assert.True(t, report.hasHardFail()) +} + +// TestRun_ManagementAPIFail verifies that a Management API failure results +// in a warning (not a hard fail) and QUIC remains the suggested protocol. +func TestRun_ManagementAPIFail(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + // twoRegionAddrs has 2 regions × 1 V4 address each; each succeeds on first try. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection refused")).AnyTimes() + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS Pass + 2 QUIC Pass + 2 HTTP2 Pass + 1 API Fail. 
+ requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Fail) + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) + assert.False(t, report.hasHardFail()) + assert.True(t, report.hasWarn()) +} + +// TestRun_RegionFlagForwardedToDNS verifies that the --region flag is passed +// verbatim to the DNS resolver and that regional hostnames appear in the results. +func TestRun_RegionFlagForwardedToDNS(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + // The region string must be forwarded verbatim to the DNS resolver. + dns.EXPECT().Resolve("us").Return(twoRegionAddrs(), nil) + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Region: "us", Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // DNS rows carry regional hostnames (indices 0 and 1). + assert.Equal(t, "us-region1.v2.argotunnel.com", report.Results[0].Target, "DNS region1") + assert.Equal(t, "us-region2.v2.argotunnel.com", report.Results[1].Target, "DNS region2") + + // Transport rows reuse the same regional hostnames (QUIC: 2,3 / HTTP2: 4,5). 
+ assert.Equal(t, "us-region1.v2.argotunnel.com", report.Results[2].Target, "QUIC region1") + assert.Equal(t, "us-region2.v2.argotunnel.com", report.Results[3].Target, "QUIC region2") + assert.Equal(t, "us-region1.v2.argotunnel.com", report.Results[4].Target, "HTTP2 region1") + assert.Equal(t, "us-region2.v2.argotunnel.com", report.Results[5].Target, "HTTP2 region2") +} + +// TestRun_QUICUsesProbeConnIndex verifies that the QUIC probe always uses the +// reserved sentinel connIndex (math.MaxUint8 = 255) to bypass port-reuse checks. +func TestRun_QUICUsesProbeConnIndex(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil) + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + // connIndex must be the reserved sentinel (math.MaxUint8 = 255), never 0. + // twoRegionAddrs has 2 regions × 1 V4 address each → 2 calls. + quicD.EXPECT().DialQuic( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + gomock.Eq(uint8(math.MaxUint8)), + gomock.Any(), gomock.Any(), + ).Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) +} + +// TestRun_BothFamiliesProbed verifies that when both V4 and V6 addresses are +// present in the DNS response, both are probed (2 regions × 2 families = 4 dials). 
+func TestRun_BothFamiliesProbed(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) + // 2 regions × 2 families = 4 dial calls each for QUIC and HTTP/2. + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(4) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(4) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.Auto}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + // 2 DNS + 2 QUIC + 2 HTTP2 + 1 API = 7 results, all passing. + requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) + require.NotNil(t, report.SuggestedProtocol) + assert.Equal(t, connection.QUIC, *report.SuggestedProtocol) +} + +// TestRun_IPv4OnlySkipsV6 verifies that when IPv4Only is configured only V4 +// addresses are probed (2 regions × 1 V4 = 2 dials per transport). +func TestRun_IPv4OnlySkipsV6(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) + // IPv4Only: only V4 addresses are probed → 2 regions × 1 V4 = 2 calls each. + // V6 addresses must never be dialed (NOTE: only the Times(2) call count enforces this; every argument matcher is gomock.Any()). + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) +} + +// TestRun_IPv6OnlySkipsV4 verifies that when IPv6Only is configured only V6 +// addresses are probed (2 regions × 1 V6 = 2 dials per transport). +func TestRun_IPv6OnlySkipsV4(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + dns := mocks.NewMockDNSResolver(ctrl) + tcp := mocks.NewMockTCPDialer(ctrl) + quicD := mocks.NewMockQUICDialer(ctrl) + mgmt := mocks.NewMockManagementDialer(ctrl) + + dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil) + // IPv6Only: only V6 addresses are probed → 2 regions × 1 V6 = 2 calls each. + // V4 addresses must never be dialed (NOTE: only the Times(2) call count enforces this; every argument matcher is gomock.Any()). + tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nopConn{}, nil).Times(2) + quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(&fakeQUICConn{}, nil).Times(2) + mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(nopConn{}, nil) + + report := Run(t.Context(), Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only}, + nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt}) + + requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass) +} diff --git a/prechecks/result_test.go b/prechecks/result_test.go index 8864bbbb..650f9d19 100644 --- a/prechecks/result_test.go +++ b/prechecks/result_test.go @@ -358,7 +358,6 @@ func TestLogEvent_BothTransportsBlocked(t *testing.T) { assert.Equal(t, "fail", entries[3].Status) assert.Equal(t, "Blocked or unreachable", entries[3].Details) - // Summary: hard fail is true. summary := entries[len(entries)-1] require.NotNil(t, summary.HardFail) assert.True(t, *summary.HardFail) @@ -381,7 +380,6 @@ func TestLogEvent_DNSFail(t *testing.T) { assert.Equal(t, "skip", entries[3].Status) assert.Equal(t, "DNS prerequisite failed", entries[3].Details) - // Summary: hard fail is true. summary := entries[len(entries)-1] require.NotNil(t, summary.HardFail) assert.True(t, *summary.HardFail) diff --git a/prechecks/types.go b/prechecks/types.go index b28134ed..7b74a34d 100644 --- a/prechecks/types.go +++ b/prechecks/types.go @@ -57,22 +57,20 @@ type CheckResult struct { Type ProbeType // Component is the human-readable probe category shown in the table header - // column, e.g. "DNS Resolution", "QUIC Connectivity". + // column. Component string - // Target is the address or resource that was probed, e.g. - // "region1.v2.argotunnel.com" or "Port 7844 (QUIC)". + // Target is the address or resource that was probed. Target string // ProbeStatus is the outcome of the probe. ProbeStatus Status - // Details is a short description of the result shown in the table, e.g. - // "Resolved successfully" or "Handshake failed".
+ // Details is a short description of the result shown in the table. Details string // Action is non-empty when ProbeStatus is Fail and contains a human-readable - // remediation instruction, e.g. "Allow outbound QUIC on port 7844." + // remediation instruction. Action string } diff --git a/vendor/github.com/cloudflare/backoff/.travis.yml b/vendor/github.com/cloudflare/backoff/.travis.yml new file mode 100644 index 00000000..3a1e1cb3 --- /dev/null +++ b/vendor/github.com/cloudflare/backoff/.travis.yml @@ -0,0 +1,24 @@ +sudo: false +language: go +go: + - 1.6 + - 1.7 + - tip + +before_script: + - go get github.com/GeertJohan/fgt + - go get github.com/golang/lint/golint + - go get golang.org/x/tools/cmd/goimports + - go get honnef.co/go/staticcheck/cmd/staticcheck + +script: + - find . -name \*.go | xargs fgt goimports -l + - fgt go vet ./... + - fgt golint ./... + - fgt staticcheck ./... + - go test ./... + +notifications: + email: + recipients: + - kyle@cloudflare.com diff --git a/vendor/github.com/cloudflare/backoff/LICENSE b/vendor/github.com/cloudflare/backoff/LICENSE new file mode 100644 index 00000000..965145f7 --- /dev/null +++ b/vendor/github.com/cloudflare/backoff/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2016 CloudFlare Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/cloudflare/backoff/README.md b/vendor/github.com/cloudflare/backoff/README.md new file mode 100644 index 00000000..e1fe9e59 --- /dev/null +++ b/vendor/github.com/cloudflare/backoff/README.md @@ -0,0 +1,83 @@ +# backoff +## Go implementation of "Exponential Backoff And Jitter" + +This package implements the backoff strategy described in the AWS +Architecture Blog article +["Exponential Backoff And Jitter"](http://www.awsarchitectureblog.com/2015/03/backoff.html). Essentially, +the backoff has an interval `time.Duration`; the *nth* call +to backoff will return a `time.Duration` that is *2<sup>n</sup> * +interval*. If jitter is enabled (which is the default behaviour), the +duration is a random value between 0 and *2<sup>n</sup> * interval*. +The backoff is configured with a maximum duration that will not be +exceeded; e.g., by default, the longest duration returned is +`backoff.DefaultMaxDuration`. + +## Usage + +A `Backoff` is initialised with a call to `New`. Using zero values +causes it to use `DefaultMaxDuration` and `DefaultInterval` as the +maximum duration and interval.
+ +``` +package something + +import "github.com/cloudflare/backoff" + +func retryable() { + b := backoff.New(0, 0) + for { + err := someOperation() + if err == nil { + break + } + + log.Printf("error in someOperation: %v", err) + <-time.After(b.Duration()) + } + + log.Printf("succeeded after %d tries", b.Tries()+1) + b.Reset() +} +``` + +It can also be used to rate limit code that should retry infinitely, but which does not +use `Backoff` itself. + +``` +package something + +import ( + "time" + + "github.com/cloudflare/backoff" +) + +func retryable() { + b := backoff.New(0, 0) + b.SetDecay(30 * time.Second) + + for { + // b will reset if someOperation returns later than + // the last call to b.Duration() + 30s. + err := someOperation() + if err == nil { + break + } + + log.Printf("error in someOperation: %v", err) + <-time.After(b.Duration()) + } +} +``` + +## Tunables + +* `NewWithoutJitter` creates a Backoff that doesn't use jitter. + +The default behaviour is controlled by two variables: + +* `DefaultInterval` sets the base interval for backoffs created with + the zero `time.Duration` value in the `Interval` field. +* `DefaultMaxDuration` sets the maximum duration for backoffs created + with the zero `time.Duration` value in the `MaxDuration` field. + diff --git a/vendor/github.com/cloudflare/backoff/backoff.go b/vendor/github.com/cloudflare/backoff/backoff.go new file mode 100644 index 00000000..ee054e15 --- /dev/null +++ b/vendor/github.com/cloudflare/backoff/backoff.go @@ -0,0 +1,197 @@ +// Package backoff contains an implementation of an intelligent backoff +// strategy. It is based on the approach in the AWS architecture blog +// article titled "Exponential Backoff And Jitter", which is found at +// http://www.awsarchitectureblog.com/2015/03/backoff.html. +// +// Essentially, the backoff has an interval `time.Duration`; the nth +// call to backoff will return a `time.Duration` that is 2^n * +// interval. 
If jitter is enabled (which is the default behaviour), +// the duration is a random value between 0 and 2^n * interval. The +// backoff is configured with a maximum duration that will not be +// exceeded. +// +// At package initialisation the system's cryptographic +// random number generator is used to seed a Go math/rand random number +// source. If this fails, the package will panic on startup. +package backoff + +import ( + "crypto/rand" + "encoding/binary" + "io" + "math" + mrand "math/rand" + "sync" + "time" +) + +var prngMu sync.Mutex +var prng *mrand.Rand + +// DefaultInterval is used when a Backoff is initialised with a +// zero-value interval. +var DefaultInterval = 5 * time.Minute + +// DefaultMaxDuration is the maximum amount of time that the backoff will +// delay for when a Backoff is initialised with a zero-value max duration. +var DefaultMaxDuration = 6 * time.Hour + +// A Backoff contains the information needed to intelligently backoff +// and retry operations using an exponential backoff algorithm. It should +// be initialised with a call to `New`. +// +// Only use a Backoff from a single goroutine, it is not safe for concurrent +// access. +type Backoff struct { + // maxDuration is the largest possible duration that can be + // returned from a call to Duration. + maxDuration time.Duration + + // interval controls the time step for backing off. + interval time.Duration + + // noJitter controls whether to use the "Full Jitter" + // improvement to attempt to smooth out spikes in a high + // contention scenario. If noJitter is set to true, no + // jitter will be introduced. + noJitter bool + + // decay controls the decay of n. If it is non-zero, n is + // reset if more than the last backoff + decay has elapsed since + // the last try. + decay time.Duration + + n uint64 + lastTry time.Time +} + +// New creates a new backoff with the specified max duration and +// interval. Zero values may be used to use the default values. +// +// Panics if either max or interval is negative.
+func New(max time.Duration, interval time.Duration) *Backoff { + if max < 0 || interval < 0 { + panic("backoff: max or interval is negative") + } + + b := &Backoff{ + maxDuration: max, + interval: interval, + } + b.setup() + return b +} + +// NewWithoutJitter works similarly to New, except that the created +// Backoff will not use jitter. +func NewWithoutJitter(max time.Duration, interval time.Duration) *Backoff { + b := New(max, interval) + b.noJitter = true + return b +} + +func init() { + var buf [8]byte + var n int64 + + _, err := io.ReadFull(rand.Reader, buf[:]) + if err != nil { + panic(err.Error()) + } + + n = int64(binary.LittleEndian.Uint64(buf[:])) + + src := mrand.NewSource(n) + prng = mrand.New(src) +} + +func (b *Backoff) setup() { + if b.interval == 0 { + b.interval = DefaultInterval + } + + if b.maxDuration == 0 { + b.maxDuration = DefaultMaxDuration + } +} + +// Duration returns the delay for the current attempt (a random value below it unless noJitter is set), +// incrementing the attempt counter. +func (b *Backoff) Duration() time.Duration { + b.setup() + + b.decayN() + + t := b.duration(b.n) + + if b.n < math.MaxUint64 { + b.n++ + } + + if !b.noJitter { + prngMu.Lock() + t = time.Duration(prng.Int63n(int64(t))) + prngMu.Unlock() + } + + return t +} + +// duration returns interval * 2^n, clamped to maxDuration (including when the product overflows). +func (b *Backoff) duration(n uint64) (t time.Duration) { + // Saturate pow at math.MaxInt64 when 1 << n would not fit. + pow := time.Duration(math.MaxInt64) + if n < 63 { + pow = 1 << n + } + + t = b.interval * pow + if t/pow != b.interval || t > b.maxDuration { + t = b.maxDuration + } + + return +} + +// Reset resets the attempt counter and the last-try timestamp of a backoff. +// +// It should be called when the rate-limited action succeeds. +func (b *Backoff) Reset() { + b.lastTry = time.Time{} + b.n = 0 +} + +// SetDecay sets the duration after which the try counter will be reset. +// Panics if decay is smaller than 0. +// +// The decay only kicks in if at least the last backoff + decay has elapsed +// since the last try.
+func (b *Backoff) SetDecay(decay time.Duration) { + if decay < 0 { + panic("backoff: decay < 0") + } + + b.decay = decay +} + +// decayN resets n to 0 when more than the previous backoff plus decay has elapsed since lastTry; it is a no-op when decay is 0. +func (b *Backoff) decayN() { + if b.decay == 0 { + return + } + + if b.lastTry.IsZero() { + b.lastTry = time.Now() + return + } + + lastDuration := b.duration(b.n - 1) + decayed := time.Since(b.lastTry) > lastDuration+b.decay + b.lastTry = time.Now() + + if !decayed { + return + } + + b.n = 0 +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 9db15073..91d8f154 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -10,6 +10,9 @@ github.com/beorn7/perks/quantile # github.com/cespare/xxhash/v2 v2.3.0 ## explicit; go 1.11 github.com/cespare/xxhash/v2 +# github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 +## explicit +github.com/cloudflare/backoff # github.com/coreos/go-oidc/v3 v3.17.0 ## explicit; go 1.24.0 github.com/coreos/go-oidc/v3/oidc