Compare commits

...

20 Commits

Author SHA1 Message Date
Miguel da Costa Martins Marcelino 02eb75b56d TUN-10557: Bump quic-go v0.59.1
This adds back the quic-go bump.
2026-06-18 18:20:39 +00:00
MiguelMarcelino 81a53555aa Release 2026.6.1 2026-06-18 14:39:02 +01:00
Miguel da Costa Martins Marcelino 2bcaf09734 Revert "TUN-10557: Bump quic-go v0.59.1"
This reverts merge request !1850
2026-06-18 13:30:00 +00:00
Miguel da Costa Martins Marcelino 3315fa6e0f TUN-10630: Fix precheck protocol override
As it stands, cloudflared prechecks are not taking the `protocol` flag into consideration and is instead falling back to the default protocol, which is QUIC. Prechecks should report the protocol cloudflared will use, not the default protocol.
2026-06-18 10:56:53 +00:00
Miguel da Costa Martins Marcelino ad11e67340 chore: Fix warnings
Fixing warnings in cloudflared before making any further changes.
2026-06-16 16:53:56 +00:00
João "Pisco" Fernandes 3a60f8ac0f TUN-10612: Add renovate to cloudflared to update distroless images explicitely 2026-06-15 11:39:47 +01:00
lneto 68620efbce TUN-10557: Bump quic-go v0.59.1
Bumps quic-go to v0.59.1 (chungthuang fork rebased from upstream v0.45 onto
v0.59.1). Upstream removed the `logging` package and replaced its
callback-based ConnectionTracer with the structured `qlog`/`qlogwriter` event
API, which required migrating cloudflared's QUIC metrics collection.

Migrations:

- quic/tracing.go: connTracer no longer fills a logging.ConnectionTracer
  callback struct. It implements qlogwriter.Trace + qlogwriter.Recorder and
  dispatches qlog events (PacketSent, PacketReceived, MetricsUpdated, ...) to
  the collector through RecordEvent. NewClientTracer now returns a function
  compatible with quic.Config.Tracer.

- quic/metrics.go: collector methods take qlog types (qlog.Frame,
  qlog.PacketType, qlog.MetricsUpdated, ...) and plain int64 in place of the
  removed logging.ByteCount/Frame/RTTStats/TransportParameters.

- quic/conversion.go: PacketType, PacketDropReason and PacketLossReason are
  strings upstream rather than numeric iotas, so the converters become
  pass-through allowlists. CongestionState is also a string;
  congestionStateToFloat maps it back to the numeric gauge values cloudflared
  exports.

- quic.Connection/quic.Stream became *quic.Conn/*quic.Stream; updated
  ConnWithCloser, SafeStreamCloser and the connection package accordingly.
  Tests and generated mocks (mocks/mock_quic_connection.go) were adapted to
  the new pointer-based API.

Closes TUN-10557
2026-06-12 07:24:26 +01:00
Miguel da Costa Martins Marcelino 4d95ab73f5 TUN-9251: Publish internal image
Publishing internal image in cloudflared. This allows us to remove the dependency from cloudflare/plat/dockerfiles. In addition, our acceptance tests should now be able to use the latest image instead of relying on a fixed version for testing, which will allow us to detect potential failures earlier.
2026-06-11 12:31:50 +00:00
João "Pisco" Fernandes 57f7d693bb Release 2026.6.0 2026-06-08 19:16:09 +01:00
João "Pisco" Fernandes ccffef1179 TUN-10558: Bump go to v1.24.4, x/crypto to v0.52.0 and google.golang.org/grpc to v1.81.1
Closes TUN-10558
2026-06-08 19:15:35 +01:00
Luis Neto 52519f67e8 TUN-10563: introduce QUICConnection interface
The bump of the QUIC library introduces a cyclic dependency between the connection and quic modules hence it is necessary to break this coupling.

Right now, the connection module depends on the quic module for the datagram v2/v3 and to which a QUIC connection (currently an interface) is passed.

As it is there is no issue however, under the hood, interface is a wrapper around an UDP connection and a QUIC connection meaning this type must be exposed to the quic module since the QUIC Connection will no longer be a interface but a struct.

Given the above, these changes introduce an interface, QUICConnection, with the surface used today in cloudflared and a struct, ConnWithCloser, that implements said interface within the quic module.

Closes TUN-10563
2026-06-01 10:08:38 +01:00
João "Pisco" Fernandes 0e84636de9 Release 2026.5.2 2026-05-27 11:15:36 +01:00
Miguel da Costa Martins Marcelino 4177dd6936 TUN-10391: Avoid using fmt.Println
Avoid using fmt.Println and instead switch to logging pre-checks with the provided logger.
2026-05-26 22:04:54 +00:00
João "Pisco" Fernandes f6f60e1059 Release 2026.5.1 2026-05-25 10:32:09 +01:00
Miguel da Costa Martins Marcelino 4494eee13d TUN-10391: Add precheck integration tests
Adding integration tests for cloudflared pre-checks. This tests pre-check functionality to ensure it is working as expected.
2026-05-22 21:58:48 +00:00
Miguel da Costa Martins Marcelino 905d983d14 TUN-10391: Avoid blocking cloudflared due to logging
Pipes have a finite OS buffer (\~64KB Linux, \~4KB macOS, \~4KB Windows). Since nobody was reading stdout/stderr during the process lifetime, cloudflared would block once the buffer filled up. The post-terminate()/read() could only get whatever fit in the buffer, causing truncated logs.

There was also a race between terminate() and read(): the process might not have flushed its final output yet.

We're also deleting `test_default_only`. Since we changed `edge-ip-version` to auto, this test became redundant.
2026-05-22 18:15:54 +00:00
João "Pisco" Fernandes 168f09cb4c fix: Bump go to 1.26.3 and go.opentelemetry.io/otel and go-jose/v4 to fix CVE's 2026-05-22 17:29:40 +01:00
Miguel da Costa Martins Marcelino 0c9014870a TUN-10511: Revise --edge support for pre-checks
Fixing some bugs with DNS targets. Most importantly, these changes also fix some wrong assumptionsmade when trying to add support for the `--edge` flag:

1. Removes `StaticEdgeDNSResolver` in favor `resolveStaticEdge`. Since --edge does not imply resolving DNS, this fixes that assumption.
2. Adds EdgeAddrs, which allows us to skip DNS probes when set. This fixes the targets in the DNS rows.
3. Added a new `ResolvedTarget` struct, which joins addresses with the respective DNS results. This avoids the brittle logic we had before, where we assumed there were always two groups (one for each region) when running probes. So this not only makes the code more extensible in case we want to add more regions in the future but also adds support for multiple targets supplied via `--edge`.
4. Changes the existing nomenclature, going from calling things `region` to `target`. The term `region` works when resolving production regions (region1 and region2), but becomes misleading when we add the logic for `--edge`.

The end result of these changes is that we now see the correct addresses when you supply targets via `--edge`, while also making the code a bit clearer.
2026-05-14 09:06:02 +00:00
Miguel da Costa Martins Marcelino 31de04f858 TUN-10525: Add prechecks kill switch
Instead of having the  --precheck flag in cloudflared, we allow controlling prechecks via a DNS flag, so we can short-circuit this behavior in case anything goes wrong. Although we don't expect pre-checks to add that much traffic, we should still guarantee that we can stop pre-checks in case something goes wrong.
2026-05-13 18:05:11 +00:00
João "Pisco" Fernandes fbfd76089f fix: Update golang.org/x/net to v0.54.0
Check / check (1.22.x, ubuntu-latest) (push) Failing after 5m15s
Semgrep config / semgrep/ci (push) Failing after 1m19s
Check / check (1.22.x, macos-latest) (push) Has been cancelled
Check / check (1.22.x, windows-latest) (push) Has been cancelled
2026-05-13 13:15:15 +01:00
718 changed files with 27163 additions and 112128 deletions
+1 -1
View File
@@ -6,7 +6,7 @@ RUN apt-get update && \
apt-get install --no-install-recommends --allow-downgrades -y \
build-essential \
git \
go-boring=1.26.2-1 \
go-boring=1.26.4-1 \
libffi-dev \
procps \
python3-dev \
+34
View File
@@ -0,0 +1,34 @@
include:
- local: .ci/commons.gitlab-ci.yml
###########################################################################
### Build and Push Internal Image (commit SHA on master, version on tag) ###
###########################################################################
- component: $CI_SERVER_FQDN/cloudflare/ci/docker-image/build-push-image@~latest
inputs:
stage: release-internal
jobPrefix: internal-image
runOnMR: false
runOnBranches: '^master$'
needs:
- generate-internal-image-version
commentImageRefs: false
runner: vm-linux-x86-4cpu-8gb
EXTRA_DIB_ARGS: "--manifest=.docker-images-internal"
###############################################################################
### Generate Internal Image Version File ###
### Uses `git describe`: version tag on tagged commits, SHA-based on master ###
###############################################################################
generate-internal-image-version:
stage: release-internal
image: $BUILD_IMAGE
rules:
- !reference [.default-rules, run-on-master]
needs:
- ci-image-get-image-ref
script:
- make generate-internal-image-version
artifacts:
paths:
- versions-internal
+1 -1
View File
@@ -5,7 +5,7 @@
runner: linux-x86-8cpu-16gb
stage: build
golangVersion: "boring-1.26"
imageVersion: "3595-779e088c0ec4@sha256:a9825d640211b76915a60071e9bef3f73ad3572ce770c7c7dd36b3dd3687504c"
imageVersion: "3625-1801d52@sha256:9261597bc2d229c997522848260de758567643d58ae1097196ae368db89a1d0f"
CGO_ENABLED: 1
.default-packaging-job: &packaging-job-defaults
+1 -1
View File
@@ -8,7 +8,7 @@ include:
rules:
- !reference [.default-rules, run-always]
tags:
- windows-x86
- canary-windows-x86
cache: {}
##########################################
+8
View File
@@ -0,0 +1,8 @@
images:
- name: cloudflared-daemon
dockerfile: Dockerfile.$ARCH
context: .
version_file: versions-internal
architectures:
- amd64
- arm64
+6 -1
View File
@@ -1,5 +1,5 @@
variables:
GO_VERSION: "1.26.2"
GO_VERSION: "1.26.4"
MAC_GO_VERSION: "go@$GO_VERSION"
WIN_GO_VERSION: "go$GO_VERSION"
GIT_DEPTH: "0"
@@ -63,6 +63,11 @@ include:
#####################################################
- local: .ci/apt-internal.gitlab-ci.yml
#####################################################
########## Release Internal Docker Image ############
#####################################################
- local: .ci/internal-image.gitlab-ci.yml
#####################################################
############## Manual Claude Review #################
#####################################################
+1 -1
View File
@@ -1,7 +1,7 @@
# use a builder image for building cloudflare
ARG TARGET_GOOS
ARG TARGET_GOARCH
FROM golang:1.26.2 AS builder
FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \
CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \
+2 -2
View File
@@ -1,5 +1,5 @@
# use a builder image for building cloudflare
FROM golang:1.26.2 AS builder
FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \
CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
@@ -15,7 +15,7 @@ COPY . .
RUN GOOS=linux GOARCH=amd64 make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian13:nonroot
FROM gcr.io/distroless/base-debian13:nonroot-amd64@sha256:ced0a2b1936b14d5bddc2ee02a807b1586ca6576a967f5b043f4a3301c8a8f6b
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"
+2 -2
View File
@@ -1,5 +1,5 @@
# use a builder image for building cloudflare
FROM golang:1.26.2 AS builder
FROM golang:1.26.4 AS builder
ENV GO111MODULE=on \
CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
@@ -15,7 +15,7 @@ COPY . .
RUN GOOS=linux GOARCH=arm64 make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian13:nonroot-arm64
FROM gcr.io/distroless/base-debian13:nonroot-arm64@sha256:9c1ab6a3dbf9e22827b0be4a314d7cfbe008f922b7ca833ed0e5a63318c6169e
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"
+4
View File
@@ -159,6 +159,10 @@ container:
generate-docker-version:
echo latest $(VERSION) > versions
.PHONY: generate-internal-image-version
generate-internal-image-version:
echo $(VERSION) > versions-internal
.PHONY: test
test: vet
+23
View File
@@ -1,3 +1,26 @@
2026.6.1
- 2026-06-18 TUN-10630: Fix precheck protocol override
- 2026-06-18 Revert "TUN-10557: Bump quic-go v0.59.1"
- 2026-06-16 chore: Fix warnings
- 2026-06-15 TUN-10612: Add renovate to cloudflared to update distroless images explicitely
- 2026-06-11 TUN-9251: Publish internal image
- 2026-05-26 TUN-10557: Bump quic-go v0.59.1
2026.6.0
- 2026-06-08 TUN-10558: Bump go to v1.24.4, x/crypto to v0.52.0 and google.golang.org/grpc to v1.81.1
- 2026-06-01 TUN-10563: introduce QUICConnection interface
2026.5.2
- 2026-05-26 TUN-10391: Avoid using fmt.Println
2026.5.1
- 2026-05-22 fix: Bump go to 1.26.3 and go.opentelemetry.io/otel and go-jose/v4 to fix CVE's
- 2026-05-22 TUN-10391: Avoid blocking cloudflared due to logging
- 2026-05-22 TUN-10391: Add precheck integration tests
- 2026-05-14 TUN-10511: Revise --edge support for pre-checks
- 2026-05-13 fix: Update golang.org/x/net to v0.54.0
- 2026-05-13 TUN-10525: Add prechecks kill switch
2026.5.0
- 2026-05-08 Bump golang.org/x/net from v0.40.0 to v0.53.0
- 2026-05-07 TUN-10507: Bump go and go-boring to 1.26.2
+2 -3
View File
@@ -17,8 +17,7 @@ import (
// Websocket is used to carry data via WS binary frames over the tunnel from client to the origin
// This implements the functions for glider proxy (sock5) and the carrier interface
type Websocket struct {
log *zerolog.Logger
isSocks bool
log *zerolog.Logger
}
// NewWSConnection returns a new connection object
@@ -36,7 +35,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
ws.log.Err(err).Str(LogFieldOriginURL, options.OriginURL).Msg("failed to connect to origin")
return err
}
defer wsConn.Close()
defer func() { _ = wsConn.Close() }()
stream.Pipe(wsConn, conn, ws.log)
return nil
+27 -26
View File
@@ -2,10 +2,11 @@ package carrier
import (
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
"math/rand"
"math/big"
"testing"
"time"
@@ -23,28 +24,19 @@ import (
func websocketClientTLSConfig(t *testing.T) *tls.Config {
certPool := x509.NewCertPool()
helloCert, err := tlsconfig.GetHelloCertificateX509()
assert.NoError(t, err)
require.NoError(t, err)
certPool.AddCert(helloCert)
assert.NotNil(t, certPool)
return &tls.Config{RootCAs: certPool}
}
func TestWebsocketHeaders(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
wsHeaders := websocketHeaders(req)
for _, header := range stripWebsocketHeaders {
assert.Empty(t, wsHeaders[header])
}
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}
func TestServe(t *testing.T) {
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
listener, err := hello.CreateTLSListener("localhost:1111")
assert.NoError(t, err)
defer listener.Close()
require.NoError(t, err)
defer func() { _ = listener.Close() }()
go func() {
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
@@ -56,19 +48,25 @@ func TestServe(t *testing.T) {
assert.NotNil(t, tlsConfig)
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := clientConnect(req, &d)
assert.NoError(t, err)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
for i := 0; i < 1000; i++ {
messageSize := rand.Int()%2048 + 1
clientMessage := make([]byte, messageSize)
// rand.Read always returns len(clientMessage) and a nil error
rand.Read(clientMessage)
for range 1000 {
messageSize, err := rand.Int(rand.Reader, big.NewInt(2048))
require.NoError(t, err)
clientMessage := make([]byte, messageSize.Int64()+1)
for i := range clientMessage {
n, err := rand.Int(rand.Reader, big.NewInt(256))
n8 := uint8(n.Uint64()) //nolint:gosec // test-only
require.NoError(t, err)
clientMessage[i] = n8
}
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
assert.NoError(t, err)
require.NoError(t, err)
messageType, message, err := conn.ReadMessage()
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, websocket.BinaryFrame, messageType)
assert.Equal(t, clientMessage, message)
}
@@ -97,27 +95,30 @@ func TestWebsocketWrapper(t *testing.T) {
req := testRequest(t, testAddr, nil)
conn, resp, err := clientConnect(req, &d)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
// Websocket now connected to test server so lets check our wrapper
wrapper := cfwebsocket.GorillaConn{Conn: conn}
buf := make([]byte, 100)
wrapper.Write([]byte("abc"))
_, err = wrapper.Write([]byte("abc"))
require.NoError(t, err)
n, err := wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 3)
require.Equal(t, 3, n)
require.Equal(t, "abc", string(buf[:n]))
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
wrapper.Write([]byte("abc"))
_, err = wrapper.Write([]byte("abc"))
require.NoError(t, err)
buf = buf[:1]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 1)
require.Equal(t, 1, n)
require.Equal(t, "a", string(buf[:n]))
buf = buf[:cap(buf)]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 2)
require.Equal(t, 2, n)
require.Equal(t, "bc", string(buf[:n]))
}
+57
View File
@@ -1,6 +1,9 @@
package cliutil
import (
"strings"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
@@ -57,3 +60,57 @@ func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
FlagLogOutput,
}
}
// LogTable renders lines inside an ASCII table and logs each rendered row.
func LogTable(log *zerolog.Logger, lines []string, title ...string) {
tableTitle := ""
if len(title) > 0 {
tableTitle = title[0]
}
for _, line := range asciiBox(lines, tableTitle, 2) {
if line != "" {
log.Info().Msg(line)
}
}
}
// asciiBox wraps lines in a bordered ASCII box with an optional title row.
func asciiBox(lines []string, title string, padding int) (box []string) {
maxLen := maxLen(lines, title)
spacer := strings.Repeat(" ", padding)
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
box = append(box, border)
if title != "" {
box = append(box, renderBoxLine(centerLine(title, maxLen), maxLen, spacer))
box = append(box, border)
}
for _, line := range lines {
box = append(box, renderBoxLine(line, maxLen, spacer))
}
box = append(box, border)
return
}
// renderBoxLine pads a single line so it fills the box width.
func renderBoxLine(line string, maxLen int, spacer string) string {
return "|" + spacer + line + strings.Repeat(" ", maxLen-len(line)) + spacer + "|"
}
// centerLine pads line evenly so it is centered within width.
func centerLine(line string, width int) string {
padding := width - len(line)
leftPadding := padding / 2
rightPadding := padding - leftPadding
return strings.Repeat(" ", leftPadding) + line + strings.Repeat(" ", rightPadding)
}
// maxLen returns the longest visible line length including the title.
func maxLen(lines []string, title string) int {
max := len(title)
for _, line := range lines {
if len(line) > max {
max = len(line)
}
}
return max
}
+60
View File
@@ -0,0 +1,60 @@
package cliutil
import (
"bytes"
"encoding/json"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogTableWithoutTitle(t *testing.T) {
t.Parallel()
lines := captureTableLogs(t, []string{"first", "second"})
assert.Equal(t, []string{
"+----------+",
"| first |",
"| second |",
"+----------+",
}, lines)
}
func TestLogTableWithTitle(t *testing.T) {
t.Parallel()
lines := captureTableLogs(t, []string{"first", "second"}, "TT")
assert.Equal(t, []string{
"+----------+",
"| TT |",
"+----------+",
"| first |",
"| second |",
"+----------+",
}, lines)
}
func captureTableLogs(t *testing.T, lines []string, title ...string) []string {
t.Helper()
var buf bytes.Buffer
logger := zerolog.New(&buf)
LogTable(&logger, lines, title...)
// nolint: prealloc
var messages []string
for _, line := range bytes.Split(bytes.TrimSpace(buf.Bytes()), []byte("\n")) {
var entry struct {
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(line, &entry))
messages = append(messages, entry.Message)
}
return messages
}
-3
View File
@@ -126,9 +126,6 @@ const (
// NoPrechecks is the command line flag to skip connectivity pre-checks at startup.
NoPrechecks = "no-prechecks"
// Prechecks is the command line flag to run connectivity pre-checks at startup.
Prechecks = "prechecks"
// LogLevel is the command line flag for the cloudflared logging level
LogLevel = "loglevel"
+14 -32
View File
@@ -375,17 +375,6 @@ func StartServer(
info.Log(log)
logClientOptions(c, log)
// Run connectivity pre-checks for cloudflared. This runs in a separate
// goroutine, as we want to keep initializing cloudflared while prechecks
// are running.
if c.Bool(cfdflags.Prechecks) && !c.Bool(cfdflags.NoPrechecks) {
resolvedRegion := c.String(cfdflags.Region)
if resolvedRegion == "" && namedTunnel != nil {
resolvedRegion = namedTunnel.Credentials.Endpoint
}
go runPrechecks(c, log, resolvedRegion)
}
// this context drives the server, when it's canceled tunnel and all other components (origins, dns, etc...) should stop
ctx, cancel := context.WithCancel(c.Context)
defer cancel()
@@ -428,6 +417,13 @@ func StartServer(
}
connectorID := tunnelConfig.ClientConfig.ConnectorID
// Run connectivity pre-checks for cloudflared. This runs in a separate
// goroutine, as we want to keep initializing cloudflared while prechecks
// are running. Prechecks are controlled via DNS flag for remote kill-switch capability.
if !tunnelConfig.ClientConfig.ConnectionFeaturesSnapshot().SkipPrechecks && !c.Bool(cfdflags.NoPrechecks) {
go runPrechecks(c, log, tunnelConfig.Region)
}
// Disable ICMP packet routing for quick tunnels
if quickTunnelURL != "" {
tunnelConfig.ICMPRouterServer = nil
@@ -541,21 +537,14 @@ func runPrechecks(c *cli.Context, log *zerolog.Logger, region string) {
}
cfg := prechecks.Config{
Region: region,
IPVersion: ipVersion,
}
// Mirror the static/dynamic edge selection from supervisor/supervisor.go:
// when --edge addresses are provided, bypass DNS discovery entirely.
var dnsResolver prechecks.DNSResolver
if edgeAddrs := c.StringSlice(cfdflags.Edge); len(edgeAddrs) > 0 {
dnsResolver = &prechecks.StaticEdgeDNSResolver{Addrs: edgeAddrs, Log: log}
} else {
dnsResolver = &prechecks.EdgeDNSResolver{Log: log}
Region: region,
IPVersion: ipVersion,
EdgeAddrs: c.StringSlice(cfdflags.Edge),
ProtocolOverride: c.String(cfdflags.Protocol),
}
dialers := prechecks.RunDialers{
DNSResolver: dnsResolver,
DNSResolver: &prechecks.EdgeDNSResolver{Log: log},
TCPDialer: &prechecks.EdgeTCPDialer{},
QUICDialer: &prechecks.EdgeQUICDialer{},
ManagementDialer: &prechecks.NetManagementDialer{Dialer: net.Dialer{}},
@@ -563,8 +552,8 @@ func runPrechecks(c *cli.Context, log *zerolog.Logger, region string) {
report := prechecks.Run(c.Context, c.String(cfdflags.CACert), cfg, log, dialers)
// Output the human-readable table to console
fmt.Println(report.String())
// Output the human-readable table
cliutil.LogTable(log, report.String(), "CONNECTIVITY PRE-CHECKS")
// Also log structured results for log aggregation
report.LogEvent(log)
@@ -946,13 +935,6 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Value: false,
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.Prechecks,
Usage: "Run connectivity pre-checks at startup.",
EnvVars: []string{"TUNNEL_PRECHECKS"},
Value: false,
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.Metrics,
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
-1
View File
@@ -252,7 +252,6 @@ func prepareTunnelConfig(
QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit),
QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit),
NoPrechecks: c.Bool(flags.NoPrechecks),
Prechecks: c.Bool(flags.Prechecks),
OriginDNSService: dnsService,
OriginDialerService: originDialerService,
}
+4 -28
View File
@@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/connection"
)
@@ -44,7 +45,7 @@ func RunQuickTunnel(sc *subcommandContext) error {
if err != nil {
return errors.Wrap(err, "failed to request quick Tunnel")
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
// This will read the entire response into memory so we can print it in case of error
rsp_body, err := io.ReadAll(resp.Body)
@@ -76,12 +77,10 @@ func RunQuickTunnel(sc *subcommandContext) error {
url = "https://" + url
}
for _, line := range AsciiBox([]string{
cliutil.LogTable(sc.log, []string{
"Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):",
url,
}, 2) {
sc.log.Info().Msg(line)
}
})
if !sc.c.IsSet(flags.Protocol) {
_ = sc.c.Set(flags.Protocol, "quic")
@@ -116,26 +115,3 @@ type QuickTunnel struct {
AccountTag string `json:"account_tag"`
Secret []byte `json:"secret"`
}
// Print out the given lines in a nice ASCII box.
func AsciiBox(lines []string, padding int) (box []string) {
maxLen := maxLen(lines)
spacer := strings.Repeat(" ", padding)
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
box = append(box, border)
for _, line := range lines {
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
}
box = append(box, border)
return
}
func maxLen(lines []string) int {
max := 0
for _, line := range lines {
if len(line) > max {
max = len(line)
}
}
return max
}
+4 -31
View File
@@ -1,10 +1,9 @@
import json
import subprocess
from time import sleep
from constants import MANAGEMENT_HOST_NAME
from setup import get_config_from_file
from util import get_tunnel_connector_id
from util import get_tunnel_connector_id, CloudflaredProcess
SINGLE_CASE_TIMEOUT = 600
@@ -83,38 +82,12 @@ class CloudflaredCli:
def __enter__(self):
self.basecmd += ["run"]
self.process = subprocess.Popen(self.basecmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self.logger.info(f"Run cmd {self.basecmd}")
return self.process
self.cfd = CloudflaredProcess(self.basecmd, allow_input=False, capture_output=True)
return self.cfd
def __exit__(self, exc_type, exc_value, exc_traceback):
terminate_gracefully(self.process, self.logger, self.basecmd)
self.logger.debug(f"{self.basecmd} logs: {self.process.stderr.read()}")
def terminate_gracefully(process, logger, cmd):
process.terminate()
process_terminated = wait_for_terminate(process)
if not process_terminated:
process.kill()
logger.warning(f"{cmd}: cloudflared did not terminate within wait period. Killing process. logs: \
stdout: {process.stdout.read()}, stderr: {process.stderr.read()}")
def wait_for_terminate(opened_subprocess, attempts=10, poll_interval=1):
"""
wait_for_terminate polls the opened_subprocess every x seconds for a given number of attempts.
It returns true if the subprocess was terminated and false if it didn't.
"""
for _ in range(attempts):
if _is_process_stopped(opened_subprocess):
return True
sleep(poll_interval)
return False
def _is_process_stopped(process):
return process.poll() is not None
self.cfd.cleanup()
def cert_path():
+11
View File
@@ -5,6 +5,17 @@ MAX_LOG_LINES = 50
MANAGEMENT_HOST_NAME = "management.argotunnel.com"
# How long to wait for the cloudflared process to exit after SIGTERM before
# sending SIGKILL.
GRACEFUL_SHUTDOWN_TIMEOUT = 10
# How long to wait for each pipe reader thread to finish after the process
# exits.
READER_THREAD_JOIN_TIMEOUT = 5
# How long to wait for an expected log message to appear before giving up.
LOG_POLL_TIMEOUT = 30
# How often to re-check the accumulated log lines while polling.
LOG_POLL_INTERVAL = 0.5
def protocols():
return ["http2", "quic"]
-19
View File
@@ -17,25 +17,6 @@ class TestEdgeDiscovery:
config["edge-ip-version"] = edge_ip_version
return config
@pytest.mark.parametrize("protocol", protocols())
def test_default_only(self, tmp_path, component_tests_config, protocol):
"""
This test runs a tunnel with the default edge-ip-version (auto), which will use
whichever address family the system resolver returns first.
"""
if self.has_ipv6_only():
self.expect_address_connections(
tmp_path, component_tests_config, protocol, None, self.expect_ipv6_address)
elif self.has_ipv4_only():
self.expect_address_connections(
tmp_path, component_tests_config, protocol, None, self.expect_ipv4_address)
elif self.has_dual_stack(address_family_preference=socket.AddressFamily.AF_INET6):
self.expect_address_connections(
tmp_path, component_tests_config, protocol, None, self.expect_ipv6_address)
else:
self.expect_address_connections(
tmp_path, component_tests_config, protocol, None, self.expect_ipv4_address)
@pytest.mark.parametrize("protocol", protocols())
def test_ipv4_only(self, tmp_path, component_tests_config, protocol):
"""
+10 -7
View File
@@ -1,8 +1,9 @@
#!/usr/bin/env python
import json
import os
import time
from constants import MAX_LOG_LINES
from constants import MAX_LOG_LINES, LOG_POLL_INTERVAL, LOG_POLL_TIMEOUT
from util import start_cloudflared, wait_tunnel_ready, send_requests
# Rolling logger rotate log files after 1 MB
@@ -12,12 +13,14 @@ expect_message = "Starting Hello"
def assert_log_to_terminal(cloudflared):
for _ in range(0, MAX_LOG_LINES):
line = cloudflared.stderr.readline()
if not line:
break
if expect_message.encode() in line:
return
# All logs are drained by a background thread into cloudflared.stdout_lines.
# Poll the accumulated lines until the expected message appears.
deadline = time.monotonic() + LOG_POLL_TIMEOUT
while time.monotonic() < deadline:
for line in list(cloudflared.stdout_lines):
if expect_message.encode() in line:
return
time.sleep(LOG_POLL_INTERVAL)
raise Exception(f"terminal log doesn't contain {expect_message}")
+541
View File
@@ -0,0 +1,541 @@
#!/usr/bin/env python
"""
Integration tests for cloudflared connectivity pre-checks (TUN-10391).
Scope
-----
These tests verify the end-to-end behavior of cloudflared pre-checks:
- that the human-readable table written to the log output has the correct
structure and content,
- that structured JSON log lines are emitted with the expected fields, and
- that running the `diag` subcommand against a live tunnel instance produces a
zip archive that contains prechecks.json.
They do NOT cover every failure mode of the precheck logic — those are owned
by the unit tests in prechecks/checker_test.go which use mock dialers.
At the integration level the only reliable way to induce specific failure modes
without real firewall intervention is:
- --edge <unreachable>: StaticEdgeDNSResolver resolves the literal IP
directly (DNS row = PASS), then both QUIC and HTTP/2 probes time out
-> hard fail (both transports blocked).
This does NOT exercise the DNS-failure -> transport-skip path.
DNS failure and Management API failure cannot be triggered via CLI flags alone;
they require network-level intervention outside the component-test harness.
stdout/stderr design
--------------------
The pre-checks table is emitted via cliutil.LogTable, which wraps the content
in an ASCII box and logs each line at Info level through zerolog. zerolog
writes to stderr, which the test harness merges into stdout (stderr=STDOUT in
Popen). We poll a --logfile for the "precheck complete" sentinel before
leaving the `with` block, ensuring the goroutine has finished. We then call
cfd.terminate(). After the `with` block exits, the process is dead and all
output has been captured by CloudflaredProcess's background reader thread. We
read the accumulated lines from cfd.stdout_lines.
Box format (cliutil.asciiBox with padding=2, title="CONNECTIVITY PRE-CHECKS"):
+----...----+
| CONNECTIVITY PRE-CHECKS | (centered title)
+----...----+
| COMPONENT TARGET ... | (content rows)
...
+----...----+
"""
import json
import os
import re
import subprocess
import time
import zipfile as zipfilemod
from constants import METRICS_PORT
from util import LOGGER, start_cloudflared, wait_tunnel_ready
# ASCII box constants (cliutil.asciiBox, padding=2, title="CONNECTIVITY PRE-CHECKS")
BOX_TITLE = "CONNECTIVITY PRE-CHECKS"
BOX_BORDER_RE = re.compile(r"^\+(-+)\+$", re.MULTILINE) # matches +----...----+
COL_HEADER = "COMPONENT" # first word of the column-header row
# zerolog console format: "2006-01-02T15:04:05Z LVL <message>"
_LOG_PREFIX_RE = re.compile(r"^\S+ \w+ ")
# Component names (probes.go: componentXxx)
COMP_DNS = "DNS Resolution"
COMP_QUIC = "UDP Connectivity"
COMP_H2 = "TCP Connectivity"
COMP_API = "Cloudflare API"
# Target labels used in the rendered table.
#
# probeRegion() (checker.go:216) always overwrites the Target field of
# whatever CheckResult the inner probe function returns with the regionTarget
# hostname, so QUIC and HTTP/2 rows carry the same region hostname as the
# corresponding DNS row — not the "Port 7844 (QUIC/HTTP2)" strings that
# targetPortQUIC/targetPortHTTP2 define. Those port-label constants are only
# used in the empty-addrs SKIP branch and inside action message strings.
TARGET_API = "api.cloudflare.com:443"
TARGET_REGION1 = "region1.v2.argotunnel.com"
TARGET_REGION2 = "region2.v2.argotunnel.com"
# Details strings (probes.go: detailsXxx)
DETAILS_DNS_RESOLVED = "DNS Resolved successfully"
DETAILS_QUIC_OK = "QUIC connection successful"
DETAILS_HTTP2_OK = "HTTP/2 connection successful"
DETAILS_API_OK = "API is reachable"
DETAILS_QUIC_FAIL = "QUIC connection failed"
DETAILS_HTTP2_FAIL = "HTTP/2 connection is blocked or unreachable"
# Status labels (result.go: xyzStatus)
PASS = "PASS"
FAIL = "FAIL"
SKIP = "SKIP"
# Action prefixes (result.go: renderActions)
PREFIX_ERROR = "ERROR: "
PREFIX_WARNING = "WARNING: "
# Action messages (probes.go: actionXxx)
ACTION_QUIC_BLOCKED = "Allow outbound QUIC traffic on port 7844 or use HTTP2."
ACTION_HTTP2_BLOCKED = "Allow outbound TCP on port 7844."
# Exact summary lines (result.go: summaryLine)
SUMMARY_HEALTHY = "SUMMARY: Environment is healthy. cloudflared will use 'quic' as primary protocol."
SUMMARY_CRITICAL = "SUMMARY: Environment has critical failures. cloudflared may not be able to establish a tunnel."
# structured log constants (result.go)
LOG_MSG_PRECHECK = "precheck"
LOG_MSG_PRECHECK_COMPLETE = "precheck complete"
STATUS_PASS_LOG = "pass"
UNREACHABLE_EDGE = "192.0.2.1:7844"
# cloudflared dial timeout per probe: 5 s, up to 2 retries -> ~15 s total.
PRECHECK_POLL_TIMEOUT_SECS = 15
PRECHECK_POLL_INTERVAL_SECS = 1
# ---------- helpers ----------
def _poll_log_file_for_precheck_complete(log_file: str, timeout: float) -> list[dict]:
"""
Poll a JSON log file until a 'precheck complete' line appears or timeout
expires. Returns all precheck-related log lines found.
cloudflared's --logfile writes one JSON object per line. Polling keeps
the test fast on healthy networks and still tolerates slow CI hosts.
We re-read from the beginning of the file on every poll because the file
is append-only, small, and tracking a byte offset would add complexity with
no meaningful performance benefit for a ~15 s total window.
"""
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
lines = _read_precheck_log_lines_from_file(log_file)
if any(l.get("message") == LOG_MSG_PRECHECK_COMPLETE for l in lines):
return lines
time.sleep(PRECHECK_POLL_INTERVAL_SECS)
return _read_precheck_log_lines_from_file(log_file)
def _read_precheck_log_lines_from_file(log_file: str) -> list[dict]:
"""Parse all precheck-related JSON log lines from a --logfile path."""
result = []
try:
with open(log_file, "r") as f:
for raw_line in f:
raw_line = raw_line.strip()
if not raw_line:
continue
try:
obj = json.loads(raw_line)
except json.JSONDecodeError:
continue
msg = obj.get("message") or obj.get("msg", "")
if msg in (LOG_MSG_PRECHECK, LOG_MSG_PRECHECK_COMPLETE):
result.append(obj)
except FileNotFoundError:
pass
return result
# stdout table parse
class TableRow:
"""One data row parsed from the rendered precheck table."""
def __init__(self, component: str, target: str, status: str, details: str):
self.component = component
self.target = target
self.status = status
self.details = details
def __repr__(self):
return f"TableRow({self.component!r}, {self.target!r}, {self.status!r}, {self.details!r})"
def _strip_log_prefix(line: str) -> str:
"""Remove the zerolog console prefix ('2006-01-02T15:04:05Z LVL ') if present."""
return _LOG_PREFIX_RE.sub("", line, count=1)
def _unbox_line(line: str) -> str:
"""Strip the box border padding from a content line: '| text |' -> 'text'.
Accepts lines that may still carry a zerolog console prefix; the prefix is
removed before the box delimiters are stripped.
"""
msg = _strip_log_prefix(line)
if msg.startswith("|") and msg.endswith("|"):
return msg[1:-1].strip()
return msg.strip()
def _parse_table(stdout: str) -> list[TableRow]:
"""
Parse the data rows from a precheck table in stdout.
The table is now wrapped in an ASCII box by cliutil.LogTable. Each
content line has the form '| <content> |', optionally preceded by a
zerolog console prefix. We strip both the prefix and the box borders
before splitting on two-or-more spaces (text/tabwriter padding=2).
We skip the column-header row and stop at blank lines, SUMMARY, box
border lines, ERROR, or WARNING lines.
"""
rows = []
in_data = False
for raw_line in stdout.splitlines():
msg = _strip_log_prefix(raw_line)
line = _unbox_line(raw_line)
if line.startswith("COMPONENT"):
in_data = True
continue
if not in_data:
continue
if (line == "" or line.startswith("SUMMARY") or BOX_BORDER_RE.match(msg)
or line.startswith("ERROR") or line.startswith("WARNING")):
in_data = False
continue
parts = re.split(r" +", line.rstrip())
if len(parts) >= 3:
rows.append(TableRow(
component=parts[0],
target=parts[1],
status=parts[2],
details=parts[3] if len(parts) >= 4 else "",
))
return rows
def _rows_for(rows: list[TableRow], component: str) -> list[TableRow]:
return [r for r in rows if r.component == component]
# log assertions
def _assert_precheck_summary_log(
log_lines: list[dict],
*,
hard_fail: bool,
suggested_protocol: str | None = None,
):
"""Assert the 'precheck complete' summary log line has the expected fields."""
summary_lines = [l for l in log_lines if l.get("message") == LOG_MSG_PRECHECK_COMPLETE]
assert len(summary_lines) == 1, \
f"Expected exactly one '{LOG_MSG_PRECHECK_COMPLETE}' log line; got {summary_lines}"
summary = summary_lines[0]
assert summary.get("hard_fail") is hard_fail, \
f"Expected hard_fail={hard_fail} in summary log: {summary}"
if suggested_protocol is not None:
assert summary.get("suggested_protocol") == suggested_protocol, \
(f"Expected suggested_protocol={suggested_protocol!r}; "
f"got {summary.get('suggested_protocol')!r}")
# ---------- Tests ----------
class TestPrechecksHappyPath:
"""
On a healthy connection all probes pass. We assert:
- the full table structure (header, column header, separator)
- every row's component, target, status, and details
- no ERROR/WARNING action lines
- the exact summary line
- the structured log summary (hard_fail=false, suggested_protocol=quic)
"""
def test_prechecks_pass_on_healthy_connection(self, tmp_path, component_tests_config):
log_file = str(tmp_path / "cloudflared.log")
config = component_tests_config({"logfile": log_file})
with start_cloudflared(
tmp_path,
config,
cfd_pre_args=["tunnel", "--ha-connections", "1"],
cfd_args=["run"],
new_process=True,
capture_output=True,
) as cfd:
wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=1)
# Poll the log file for the sentinel before signalling the process.
log_lines = _poll_log_file_for_precheck_complete(
log_file, timeout=PRECHECK_POLL_TIMEOUT_SECS
)
# Signal shutdown.
cfd.terminate()
# The process is now dead. All output was captured by the background
# reader thread into cfd.stdout_lines (stderr is merged into stdout).
stdout = b"".join(cfd.stdout_lines).decode(errors="replace")
LOGGER.debug(f"[happy-path] stdout:\n{stdout}")
LOGGER.debug(f"[happy-path] log_lines:\n{log_lines}")
# Strip zerolog console prefixes so pattern matching works on raw messages.
messages = "\n".join(_strip_log_prefix(l) for l in stdout.splitlines())
# ── table structure ──────────────────────────────────────────────────
# zerolog writes to stderr which is merged into stdout by the harness.
# The table is wrapped in an ASCII box by cliutil.LogTable.
assert BOX_TITLE in messages, \
f"Expected box title '{BOX_TITLE}' in output;\ngot:\n{stdout}"
assert COL_HEADER in messages, \
f"Expected column header row in output;\ngot:\n{stdout}"
assert BOX_BORDER_RE.search(messages), \
f"Expected box border line (+---+) in output;\ngot:\n{stdout}"
# ── row content ──────────────────────────────────────────────────────
rows = _parse_table(stdout)
assert len(rows) == 7, \
f"Expected 7 rows (2 DNS + 2 QUIC + 2 HTTP/2 + 1 API); got {len(rows)}: {rows}"
dns_rows = _rows_for(rows, COMP_DNS)
assert len(dns_rows) == 2, f"Expected 2 DNS rows; got {dns_rows}"
assert dns_rows[0].target == TARGET_REGION1
assert dns_rows[1].target == TARGET_REGION2
for r in dns_rows:
assert r.status == PASS, f"DNS row not PASS: {r}"
assert r.details == DETAILS_DNS_RESOLVED, f"DNS row details wrong: {r}"
quic_rows = _rows_for(rows, COMP_QUIC)
assert len(quic_rows) == 2, f"Expected 2 QUIC rows; got {quic_rows}"
assert quic_rows[0].target == TARGET_REGION1, f"QUIC row[0] target wrong: {quic_rows[0]}"
assert quic_rows[1].target == TARGET_REGION2, f"QUIC row[1] target wrong: {quic_rows[1]}"
for r in quic_rows:
assert r.status == PASS, f"QUIC row not PASS: {r}"
assert r.details == DETAILS_QUIC_OK, f"QUIC row details wrong: {r}"
h2_rows = _rows_for(rows, COMP_H2)
assert len(h2_rows) == 2, f"Expected 2 HTTP/2 rows; got {h2_rows}"
assert h2_rows[0].target == TARGET_REGION1, f"HTTP/2 row[0] target wrong: {h2_rows[0]}"
assert h2_rows[1].target == TARGET_REGION2, f"HTTP/2 row[1] target wrong: {h2_rows[1]}"
for r in h2_rows:
assert r.status == PASS, f"HTTP/2 row not PASS: {r}"
assert r.details == DETAILS_HTTP2_OK, f"HTTP/2 row details wrong: {r}"
api_rows = _rows_for(rows, COMP_API)
assert len(api_rows) == 1, f"Expected 1 API row; got {api_rows}"
assert api_rows[0].target == TARGET_API, f"API row target wrong: {api_rows[0]}"
assert api_rows[0].status == PASS, f"API row not PASS: {api_rows[0]}"
assert api_rows[0].details == DETAILS_API_OK, f"API row details wrong: {api_rows[0]}"
# ── no action lines ──────────────────────────────────────────────────
assert PREFIX_ERROR not in messages, f"Unexpected ERROR action:\n{stdout}"
assert PREFIX_WARNING not in messages, f"Unexpected WARNING action:\n{stdout}"
# ── summary line ─────────────────────────────────────────────────────
assert SUMMARY_HEALTHY in messages, \
f"Expected healthy summary;\ngot:\n{stdout}"
# ── structured log ───────────────────────────────────────────────────
assert len(log_lines) > 0, \
"Expected at least one structured precheck log line in log file"
for line in log_lines:
if line.get("message") == LOG_MSG_PRECHECK:
assert line.get("status") == STATUS_PASS_LOG, \
f"Expected status=pass in precheck log line: {line}"
_assert_precheck_summary_log(log_lines, hard_fail=False, suggested_protocol="quic")
class TestPrechecksHardFail:
"""
When --edge points at an unreachable IP, StaticEdgeDNSResolver resolves
the literal address directly (DNS row = PASS), but both transport probes
time out -> hard fail. We assert:
- the full table structure
- DNS row: PASS (the literal IP was resolved)
- QUIC row: FAIL with correct details + ERROR action
- HTTP/2 row: FAIL with correct details + ERROR action
- API row: PASS (api.cloudflare.com:443 is independently reachable)
- the exact critical summary line
- the structured log summary (hard_fail=true)
This test does NOT call wait_tunnel_ready because the tunnel will not
connect to the unreachable address.
"""
def test_prechecks_hard_fail_when_edge_unreachable(self, tmp_path, component_tests_config):
log_file = str(tmp_path / "cloudflared.log")
config = component_tests_config({"logfile": log_file})
with start_cloudflared(
tmp_path,
config,
cfd_pre_args=[
"tunnel",
"--ha-connections", "1",
"--edge", UNREACHABLE_EDGE,
],
cfd_args=["run"],
new_process=True,
capture_output=True,
) as cfd:
log_lines = _poll_log_file_for_precheck_complete(
log_file, timeout=PRECHECK_POLL_TIMEOUT_SECS
)
cfd.terminate()
stdout = b"".join(cfd.stdout_lines).decode(errors="replace")
LOGGER.debug(f"[hard-fail] stdout:\n{stdout}")
LOGGER.debug(f"[hard-fail] log_lines:\n{log_lines}")
# Strip zerolog console prefixes so pattern matching works on raw messages.
messages = "\n".join(_strip_log_prefix(l) for l in stdout.splitlines())
# ── table structure ──────────────────────────────────────────────────
# zerolog writes to stderr which is merged into stdout by the harness.
# The table is wrapped in an ASCII box by cliutil.LogTable.
assert BOX_TITLE in messages, \
f"Expected box title '{BOX_TITLE}' in output;\ngot:\n{stdout}"
assert COL_HEADER in messages, \
f"Expected column header row in output;\ngot:\n{stdout}"
assert BOX_BORDER_RE.search(messages), \
f"Expected box border line (+---+) in output;\ngot:\n{stdout}"
# ── row content ──────────────────────────────────────────────────────
rows = _parse_table(stdout)
assert len(rows) == 4, \
f"Expected 4 rows (1 DNS + 1 QUIC + 1 HTTP/2 + 1 API); got {len(rows)}: {rows}"
dns_rows = _rows_for(rows, COMP_DNS)
assert len(dns_rows) == 1, f"Expected 1 DNS row; got {dns_rows}"
assert dns_rows[0].target == UNREACHABLE_EDGE
assert dns_rows[0].status == PASS, f"DNS row not PASS: {dns_rows[0]}"
assert dns_rows[0].details == DETAILS_DNS_RESOLVED, f"DNS row details wrong: {dns_rows[0]}"
quic_rows = _rows_for(rows, COMP_QUIC)
assert len(quic_rows) == 1, f"Expected 1 QUIC row; got {quic_rows}"
assert quic_rows[0].target == UNREACHABLE_EDGE, f"QUIC row target wrong: {quic_rows[0]}"
assert quic_rows[0].status == FAIL, f"QUIC row not FAIL: {quic_rows[0]}"
assert quic_rows[0].details == DETAILS_QUIC_FAIL, f"QUIC row details wrong: {quic_rows[0]}"
h2_rows = _rows_for(rows, COMP_H2)
assert len(h2_rows) == 1, f"Expected 1 HTTP/2 row; got {h2_rows}"
assert h2_rows[0].target == UNREACHABLE_EDGE, f"HTTP/2 row target wrong: {h2_rows[0]}"
assert h2_rows[0].status == FAIL, f"HTTP/2 row not FAIL: {h2_rows[0]}"
assert h2_rows[0].details == DETAILS_HTTP2_FAIL, f"HTTP/2 row details wrong: {h2_rows[0]}"
api_rows = _rows_for(rows, COMP_API)
assert len(api_rows) == 1, f"Expected 1 API row; got {api_rows}"
assert api_rows[0].target == TARGET_API, f"API row target wrong: {api_rows[0]}"
assert api_rows[0].status == PASS, f"API row not PASS: {api_rows[0]}"
assert api_rows[0].details == DETAILS_API_OK, f"API row details wrong: {api_rows[0]}"
assert f"{PREFIX_ERROR}{ACTION_QUIC_BLOCKED}" in messages, \
f"Expected QUIC ERROR action;\ngot:\n{stdout}"
assert f"{PREFIX_ERROR}{ACTION_HTTP2_BLOCKED}" in messages, \
f"Expected HTTP/2 ERROR action;\ngot:\n{stdout}"
assert SUMMARY_CRITICAL in messages, \
f"Expected critical summary;\ngot:\n{stdout}"
_assert_precheck_summary_log(log_lines, hard_fail=True, suggested_protocol=None)
class TestPreChecksDiag:
"""
Verify that `cloudflared tunnel diag` includes prechecks.json in the
diagnostic zip archive produced against a live tunnel instance.
The precheck job in diagnostic.go is gated on noDiagNetwork; we do NOT
pass --no-diag-network so prechecks.json must be present. We skip the
heavier collectors (logs, metrics, system, runtime) to keep the test fast.
The diag subcommand writes the zip to its current working directory. We
run it with cwd=tmp_path so the archive lands there and is cleaned up
automatically by pytest. We resolve config.cloudflared_binary to an
absolute path before changing cwd, because the binary path may be relative
to the original working directory.
"""
def test_diag_contains_prechecks_json(self, tmp_path, component_tests_config):
config = component_tests_config()
binary = os.path.abspath(config.cloudflared_binary)
with start_cloudflared(
tmp_path,
config,
cfd_pre_args=["tunnel", "--ha-connections", "1"],
cfd_args=["run"],
new_process=True,
capture_output=True,
) as cfd:
wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=1)
# Run the diag subcommand as a one-shot process against the
# already-running instance. We skip log/metrics/system/runtime
# collectors; the network collector (which runs prechecks) is left
# enabled.
diag_result = subprocess.run(
[
binary,
"tunnel",
"diag",
"--metrics", f"localhost:{METRICS_PORT}",
"--no-diag-logs",
"--no-diag-metrics",
"--no-diag-system",
"--no-diag-runtime",
],
cwd=str(tmp_path),
capture_output=True,
timeout=60,
)
cfd.terminate()
diag_stdout = diag_result.stdout.decode(errors="replace")
diag_stderr = diag_result.stderr.decode(errors="replace")
LOGGER.debug(f"[diag] stdout:\n{diag_stdout}")
LOGGER.debug(f"[diag] stderr:\n{diag_stderr}")
assert diag_result.returncode == 0, (
f"cloudflared tunnel diag exited with code {diag_result.returncode}\n"
f"stdout:\n{diag_stdout}\nstderr:\n{diag_stderr}"
)
# Locate the zip file written to tmp_path by the diag command.
zip_files = list(tmp_path.glob("cloudflared-diag-*.zip"))
assert len(zip_files) == 1, \
f"Expected exactly one cloudflared-diag-*.zip in {tmp_path}; found {zip_files}"
zip_path = zip_files[0]
with zipfilemod.ZipFile(zip_path) as zf:
names = zf.namelist()
LOGGER.debug(f"[diag] zip contents: {names}")
assert "prechecks.json" in names, \
f"Expected prechecks.json in diag zip; got: {names}"
# Must be valid JSON containing at least the RunID field that
# prechecks.Run() always sets.
with zf.open("prechecks.json") as fh:
data = json.load(fh)
assert "RunID" in data, \
f"Expected RunID key in prechecks.json; got keys: {list(data.keys())}"
+62 -8
View File
@@ -2,6 +2,7 @@ import logging
import os
import platform
import subprocess
import threading
from contextlib import contextmanager
from time import sleep
import sys
@@ -12,7 +13,65 @@ import requests
import yaml
from retrying import retry
from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS
from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS, GRACEFUL_SHUTDOWN_TIMEOUT, READER_THREAD_JOIN_TIMEOUT
class CloudflaredProcess:
"""
Wrapper around a Popen process that continuously drains stdout and stderr
in background threads to prevent OS pipe buffers from filling up and
blocking the child process. Captured output is logged when the process
is cleaned up.
"""
def __init__(self, cmd, allow_input, capture_output):
output = subprocess.PIPE if capture_output else subprocess.DEVNULL
stdin = subprocess.PIPE if allow_input else None
self.process = subprocess.Popen(cmd, stdin=stdin, stdout=output, stderr=subprocess.STDOUT)
self._capture_output = capture_output
self._stdout_lines = []
self._threads = []
if capture_output:
self._threads.append(self._start_reader(self.process.stdout, self._stdout_lines))
@staticmethod
def _start_reader(pipe, sink):
def _drain():
for line in pipe:
sink.append(line)
pipe.close()
t = threading.Thread(target=_drain, daemon=True)
t.start()
return t
def terminate(self):
"""Terminate the process if it is still running."""
if self.process.poll() is None:
self.process.terminate()
def cleanup(self):
"""Terminate, wait for exit, join reader threads, and log output."""
self.terminate()
try:
self.process.wait(timeout=GRACEFUL_SHUTDOWN_TIMEOUT)
except subprocess.TimeoutExpired:
self.process.kill()
self.process.wait()
for t in self._threads:
t.join(timeout=READER_THREAD_JOIN_TIMEOUT)
if self._capture_output:
stdout = b"".join(self._stdout_lines).decode("utf-8", errors="replace")
if stdout:
LOGGER.info(f"cloudflared stdout:\n{stdout}")
@property
def stdout_lines(self):
return self._stdout_lines
# Proxy common Popen attributes so callers can still use the wrapper
# as if it were a Popen (e.g. send_signal, stdin, pid, returncode).
def __getattr__(self, name):
return getattr(self.process, name)
def configure_logger():
logger = logging.getLogger(__name__)
@@ -75,20 +134,15 @@ def cloudflared_cmd(config, config_path, cfd_args, cfd_pre_args, root):
LOGGER.info(f"Run cmd {cmd} with config {config}")
return cmd
@contextmanager
def run_cloudflared_background(cmd, allow_input, capture_output):
output = subprocess.PIPE if capture_output else subprocess.DEVNULL
stdin = subprocess.PIPE if allow_input else None
cfd = None
try:
cfd = subprocess.Popen(cmd, stdin=stdin, stdout=output, stderr=output)
cfd = CloudflaredProcess(cmd, allow_input, capture_output)
yield cfd
finally:
if cfd:
cfd.terminate()
if capture_output:
LOGGER.info(f"cloudflared log: {cfd.stderr.read()}")
cfd.cleanup()
def get_quicktunnel_url():
+3 -1
View File
@@ -84,7 +84,7 @@ type TunnelToken struct {
}
func (t TunnelToken) Credentials() Credentials {
// nolint: gosimple
// nolint: staticcheck
return Credentials{
AccountTag: t.AccountTag,
TunnelSecret: t.TunnelSecret,
@@ -122,6 +122,7 @@ const (
// ShouldFlush returns whether this kind of connection should actively flush data
func (t Type) shouldFlush() bool {
// nolint: exhaustive
switch t {
case TypeWebsocket, TypeTCP, TypeControlStream:
return true
@@ -131,6 +132,7 @@ func (t Type) shouldFlush() bool {
}
func (t Type) String() string {
// nolint: exhaustive
switch t {
case TypeWebsocket:
return "websocket"
+2 -2
View File
@@ -146,8 +146,8 @@ func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
case <-wsCtx.Done():
case <-r.Context().Done():
}
readPipe.Close()
writePipe.Close()
_ = readPipe.Close()
_ = writePipe.Close()
}()
originConn := &echoPipe{reader: readPipe, writer: writePipe}
+3 -18
View File
@@ -13,6 +13,7 @@ import (
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection/dialopts"
cfdquic "github.com/cloudflare/cloudflared/quic"
)
var (
@@ -29,7 +30,7 @@ func DialQuic(
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error) {
) (cfdquic.QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, opts, logger)
if err != nil {
return nil, err
@@ -43,11 +44,7 @@ func DialQuic(
}
// wrap the session, so that the UDPConn is closed after session is closed.
conn = &wrapCloseableConnQuicConnection{
conn,
udpConn,
}
return conn, nil
return cfdquic.NewQUICConnection(conn, udpConn)
}
func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, opts dialopts.DialOpts, logger *zerolog.Logger) (*net.UDPConn, error) {
@@ -96,15 +93,3 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.Add
return udpConn, err
}
type wrapCloseableConnQuicConnection struct {
quic.Connection
udpConn *net.UDPConn
}
func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error {
err := w.Connection.CloseWithError(errorCode, reason)
_ = w.udpConn.Close()
return err
}
+6 -6
View File
@@ -41,7 +41,7 @@ const (
// quicConnection represents the type that facilitates Proxying via QUIC streams.
type quicConnection struct {
conn quic.Connection
conn cfdquic.QUICConnection
logger *zerolog.Logger
orchestrator Orchestrator
datagramHandler DatagramSessionHandler
@@ -54,10 +54,10 @@ type quicConnection struct {
gracePeriod time.Duration
}
// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
// NewTunnelConnection takes a [cfdquic.QUICConnection] to wrap it for use with cloudflared application logic.
func NewTunnelConnection(
ctx context.Context,
conn quic.Connection,
conn cfdquic.QUICConnection,
connIndex uint8,
orchestrator Orchestrator,
datagramSessionHandler DatagramSessionHandler,
@@ -143,7 +143,7 @@ func (q *quicConnection) Serve(ctx context.Context) error {
}
// serveControlStream will serve the RPC; blocking until the control plane is done.
func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
func (q *quicConnection) serveControlStream(ctx context.Context, controlStream *quic.Stream) error {
return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions.ConnectionOptions(), q.orchestrator)
}
@@ -166,10 +166,10 @@ func (q *quicConnection) acceptStream(ctx context.Context) error {
}
}
func (q *quicConnection) runStream(quicStream quic.Stream) {
func (q *quicConnection) runStream(quicStream *quic.Stream) {
ctx := quicStream.Context()
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
defer func() { _ = stream.Close() }()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
+3 -3
View File
@@ -530,7 +530,7 @@ func TestServeUDPSession(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
// Establish QUIC connection with edge
edgeQUICSessionChan := make(chan quic.Connection)
edgeQUICSessionChan := make(chan *quic.Conn)
go func() {
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
assert.NoError(t, err)
@@ -779,7 +779,7 @@ func TestDialQuicWithSkipPortReuse(t *testing.T) {
<-serverDone
}
func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession cfdquic.QUICConnection, closeType closeReason, expectedReason string, t *testing.T) {
payload := []byte(t.Name())
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
@@ -843,7 +843,7 @@ const (
closedByTimeout
)
func runRPCServer(ctx context.Context, session quic.Connection, sessionRPCServer pogs.SessionManager, configRPCServer pogs.ConfigurationManager, t *testing.T) {
func runRPCServer(ctx context.Context, session cfdquic.QUICConnection, sessionRPCServer pogs.SessionManager, configRPCServer pogs.ConfigurationManager, t *testing.T) {
stream, err := session.AcceptStream(ctx)
require.NoError(t, err)
+8 -13
View File
@@ -9,8 +9,6 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
@@ -25,7 +23,6 @@ import (
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
)
@@ -34,9 +31,7 @@ const (
demuxChanCapacity = 16
)
var (
errInvalidDestinationIP = errors.New("unable to parse destination IP")
)
var errInvalidDestinationIP = errors.New("unable to parse destination IP")
// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming
// connection streams.
@@ -47,7 +42,7 @@ type DatagramSessionHandler interface {
}
type datagramV2Connection struct {
conn quic.Connection
conn cfdquic.QUICConnection
index uint8
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
@@ -69,7 +64,7 @@ type datagramV2Connection struct {
}
func NewDatagramV2Connection(ctx context.Context,
conn quic.Connection,
conn cfdquic.QUICConnection,
originDialer ingress.OriginUDPDialer,
icmpRouter ingress.ICMPRouter,
index uint8,
@@ -116,7 +111,7 @@ func (d *datagramV2Connection) Serve(ctx context.Context) error {
}
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger)
ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes(
attribute.String("session-id", sessionID.String()),
@@ -128,7 +123,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
err := errors.Wrap(err, "failed to start udp session due to rate limiting")
tracing.EndWithErrorStatus(registerSpan, err)
return nil, err
}
@@ -166,7 +161,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
if err != nil {
originProxy.Close()
_ = originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
@@ -185,7 +180,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
Msgf("Registered session")
tracing.End(registerSpan)
resp := tunnelpogs.RegisterUdpSessionResponse{
resp := pogs.RegisterUdpSessionResponse{
Spans: traceCtx.GetProtoSpans(),
}
@@ -229,7 +224,7 @@ func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uu
}
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
defer func() { _ = stream.Close() }()
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
+2 -62
View File
@@ -1,13 +1,11 @@
package connection
import (
"context"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -16,73 +14,15 @@ import (
"github.com/cloudflare/cloudflared/mocks"
)
type mockQuicConnection struct{}
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) LocalAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) RemoteAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, s string) error {
return nil
}
func (m *mockQuicConnection) Context() context.Context {
return nil
}
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
panic("not meant to be called")
}
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
return nil
}
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
return nil, nil
}
func (m *mockQuicConnection) AddPath(*quic.Transport) (*quic.Path, error) {
return nil, nil
}
func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
log := zerolog.Nop()
conn := &mockQuicConnection{}
ctrl := gomock.NewController(t)
flowLimiterMock := mocks.NewMockLimiter(ctrl)
connMock := mocks.NewMockQUICConnection(ctrl)
datagramConn := NewDatagramV2Connection(
t.Context(),
conn,
connMock,
nil,
nil,
0,
+9 -9
View File
@@ -7,12 +7,12 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
cfdquic "github.com/cloudflare/cloudflared/quic/v3"
cfdquic "github.com/cloudflare/cloudflared/quic"
v3 "github.com/cloudflare/cloudflared/quic/v3"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@@ -22,20 +22,20 @@ var (
)
type datagramV3Connection struct {
conn quic.Connection
conn cfdquic.QUICConnection
index uint8
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer cfdquic.DatagramConn
metrics cfdquic.Metrics
datagramMuxer v3.DatagramConn
metrics v3.Metrics
logger *zerolog.Logger
}
func NewDatagramV3Connection(ctx context.Context,
conn quic.Connection,
sessionManager cfdquic.SessionManager,
conn cfdquic.QUICConnection,
sessionManager v3.SessionManager,
icmpRouter ingress.ICMPRouter,
index uint8,
metrics cfdquic.Metrics,
metrics v3.Metrics,
logger *zerolog.Logger,
) DatagramSessionHandler {
log := logger.
@@ -43,7 +43,7 @@ func NewDatagramV3Connection(ctx context.Context,
Int(management.EventTypeKey, int(management.UDP)).
Uint8(LogFieldConnIndex, index).
Logger()
datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
datagramMuxer := v3.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
return &datagramV3Connection{
conn,
+4
View File
@@ -42,6 +42,10 @@ type FeatureSnapshot struct {
// We provide the list of features since we need it to send in the ConnectionOptions during connection
// registrations.
FeaturesList []string
// SkipPrechecks indicates when to skip connectivity pre-checks at startup.
// Controlled via DNS TXT record to allow remote kill-switch in case of issues.
SkipPrechecks bool
}
type PostQuantumMode uint8
+9 -1
View File
@@ -24,6 +24,7 @@ const (
type featuresRecord struct {
DatagramV3Percentage uint32 `json:"dv3_2"`
SkipPrechecks bool `json:"skip_prechecks"`
// DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291
// DatagramV3Percentage uint32 `json:"dv3_1"` // Removed in TUN-9883
@@ -89,6 +90,7 @@ func (fs *featureSelector) Snapshot() FeatureSnapshot {
PostQuantum: fs.postQuantumMode(),
DatagramVersion: fs.datagramVersion(),
FeaturesList: fs.clientFeatures(),
SkipPrechecks: fs.prechecksSkip(),
}
}
@@ -121,6 +123,12 @@ func (fs *featureSelector) datagramVersion() DatagramVersion {
return DatagramV2
}
// prechecksSkip returns whether prechecks are enabled via DNS flag.
// Defaults to false if not set in the DNS TXT record.
func (fs *featureSelector) prechecksSkip() bool {
return fs.remoteFeatures.SkipPrechecks
}
// clientFeatures will return the list of currently available features that cloudflared should provide to the edge.
func (fs *featureSelector) clientFeatures() []string {
// Evaluate any remote features along with static feature list to construct the list of features
@@ -186,7 +194,7 @@ func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) {
}
if len(records) == 0 {
return nil, fmt.Errorf("No TXT record found for %s to determine which features to opt-in", featureSelectorHostname)
return nil, fmt.Errorf("no TXT record found for %s to determine which features to opt-in", featureSelectorHostname)
}
return []byte(records[0]), nil
+23 -27
View File
@@ -12,7 +12,7 @@ require (
github.com/getsentry/sentry-go v0.43.0
github.com/go-chi/chi/v5 v5.2.2
github.com/go-chi/cors v1.2.1
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-jose/go-jose/v4 v4.1.4
github.com/gobwas/ws v1.2.1
github.com/google/gopacket v1.1.19
github.com/google/uuid v1.6.0
@@ -23,25 +23,25 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.22.0
github.com/prometheus/client_model v0.6.2
github.com/quic-go/quic-go v0.52.0
github.com/quic-go/quic-go v0.59.1
github.com/rs/zerolog v1.20.0
github.com/shirou/gopsutil/v4 v4.26.3
github.com/stretchr/testify v1.11.1
github.com/urfave/cli/v2 v2.3.0
go.opentelemetry.io/contrib/propagators v0.22.0
go.opentelemetry.io/otel v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0
go.opentelemetry.io/otel/sdk v1.40.0
go.opentelemetry.io/otel/trace v1.40.0
go.opentelemetry.io/proto/otlp v1.2.0
go.opentelemetry.io/otel v1.43.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0
go.opentelemetry.io/otel/sdk v1.43.0
go.opentelemetry.io/otel/trace v1.43.0
go.opentelemetry.io/proto/otlp v1.10.0
go.uber.org/automaxprocs v1.6.0
go.uber.org/mock v0.5.1
golang.org/x/crypto v0.50.0
golang.org/x/net v0.53.0
go.uber.org/mock v0.5.2
golang.org/x/crypto v0.52.0
golang.org/x/net v0.55.0
golang.org/x/sync v0.20.0
golang.org/x/sys v0.43.0
golang.org/x/term v0.42.0
google.golang.org/protobuf v1.36.6
golang.org/x/sys v0.45.0
golang.org/x/term v0.43.0
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
nhooyr.io/websocket v1.8.7
@@ -65,11 +65,9 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/validator/v10 v10.15.1 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/google/pprof v0.0.0-20250418163039-24c5476c6587 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
@@ -77,7 +75,6 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo/v2 v2.23.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
@@ -89,15 +86,15 @@ require (
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/metric v1.40.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
golang.org/x/arch v0.4.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 // indirect
google.golang.org/grpc v1.72.2 // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/tools v0.44.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/grpc v1.81.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
@@ -108,5 +105,4 @@ replace github.com/prometheus/golang_client => github.com/prometheus/golang_clie
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
// This fork is based on quic-go v0.45
replace github.com/quic-go/quic-go => github.com/chungthuang/quic-go v0.45.1-0.20250428085412-43229ad201fd
replace github.com/quic-go/quic-go => github.com/chungthuang/quic-go v0.45.1-0.20260529212404-a9fddf436fc4 // This fork is based on quic-go v0.59.1
+48 -54
View File
@@ -9,8 +9,8 @@ github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chungthuang/quic-go v0.45.1-0.20250428085412-43229ad201fd h1:VdYI5zFQ2h1/qzoC6rhyPx479bkF8i177Qpg4Q2n1vk=
github.com/chungthuang/quic-go v0.45.1-0.20250428085412-43229ad201fd/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ=
github.com/chungthuang/quic-go v0.45.1-0.20260529212404-a9fddf436fc4 h1:ZaFGQi6lUEnMyl0DvRy2mEp9u7FP+FrUBr7q+c4U68o=
github.com/chungthuang/quic-go v0.45.1-0.20260529212404-a9fddf436fc4/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ=
github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -60,8 +60,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -79,8 +79,6 @@ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.15.1 h1:BSe8uhN+xQ4r5guV/ywQI4gO59C2raYcGffYWZEjZzM=
github.com/go-playground/validator/v10 v10.15.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
@@ -104,15 +102,13 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/pprof v0.0.0-20250418163039-24c5476c6587 h1:b/8HpQhvKLSNzH5oTXN2WkNcMl6YB5K3FRbb+i+Ml34=
github.com/google/pprof v0.0.0-20250418163039-24c5476c6587/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d h1:PRDnysJ9dF1vUMmEzBu6aHQeUluSQy4eWH3RsSSy/vI=
github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -150,10 +146,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus=
github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU=
github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
@@ -220,43 +212,43 @@ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbE
go.opentelemetry.io/contrib/propagators v0.22.0 h1:KGdv58M2//veiYLIhb31mofaI2LgkIPXXAZVeYVyfd8=
go.opentelemetry.io/contrib/propagators v0.22.0/go.mod h1:xGOuXr6lLIF9BXipA4pm6UuOSI0M98U6tsI3khbOiwU=
go.opentelemetry.io/otel v1.0.0-RC2/go.mod h1:w1thVQ7qbAy8MHb0IFj8a5Q2QU0l2ksf8u/CN8m3NOM=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0 h1:1u/AyyOqAWzy+SkPxDpahCNZParHV8Vid1RnI2clyDE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0/go.mod h1:z46paqbJ9l7c9fIPCXTqTGwhQZ5XoTIsfeFYWboizjs=
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
go.opentelemetry.io/otel/trace v1.0.0-RC2/go.mod h1:JPQ+z6nNw9mqEGT8o3eoPTdnNI+Aj5JcxEsVGREIAy4=
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IOkz94=
go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.1 h1:ASgazW/qBmR+A32MYFDB6E2POoTgOwT509VP0CT/fjs=
go.uber.org/mock v0.5.1/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
@@ -269,31 +261,33 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 h1:vPV0tzlsK6EzEDHNNH5sa7Hs9bd7iXR7B1tSiPepkV0=
google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:pKLAc5OolXC3ViWGI62vvC0n10CpwAtRcTNCFwTKBEw=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9 h1:IkAfh6J/yllPtpYFU0zZN1hUPYdT0ogkBT/9hMxHjvg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250512202823-5a2f75b736a9/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
google.golang.org/grpc v1.72.2 h1:TdbGzwb82ty4OusHWepvFWGLgIbNo1/SUynEN0ssqv8=
google.golang.org/grpc v1.72.2/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 h1:tu/dtnW1o3wfaxCOjSLn5IRX4YDcJrtlpzYkhHhGaC4=
google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171/go.mod h1:M5krXqk4GhBKvB596udGL3UyjL4I1+cTbK0orROM9ng=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+1 -1
View File
@@ -43,7 +43,7 @@ func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *
func (tc *tcpConnection) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
if err := tc.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
}
}
+23 -62
View File
@@ -13,7 +13,6 @@ import (
"time"
"github.com/gobwas/ws/wsutil"
gorillaWS "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/proxy"
@@ -61,7 +60,7 @@ func TestStreamTCPConnection(t *testing.T) {
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
_ = originConn.Close()
return nil
})
@@ -88,7 +87,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
_ = originConn.Close()
return nil
})
@@ -117,14 +116,14 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
for _, status := range statusCodes {
handler := func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, []byte(sendMessage), body)
assert.NoError(t, err)
assert.Equal(t, []byte(sendMessage), body)
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
assert.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
w.WriteHeader(status)
w.Write([]byte(echoMessage))
_, _ = w.Write([]byte(echoMessage))
}
origin := httptest.NewServer(http.HandlerFunc(handler))
defer origin.Close()
@@ -156,7 +155,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
errGroup.Go(func() error {
wsForwarderInConn, err := wsForwarderListener.Accept()
require.NoError(t, err)
defer wsForwarderInConn.Close()
defer func() { _ = wsForwarderInConn.Close() }()
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
return nil
@@ -171,20 +170,22 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
assert.NoError(t, err)
require.NoError(t, err)
defer func() { _ = req.Body.Close() }()
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
resp, err := transport.RoundTrip(req)
assert.NoError(t, err)
require.NoError(t, err)
defer func() { _ = resp.Body.Close() }()
assert.Equal(t, status, resp.StatusCode)
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, []byte(echoMessage), body)
wsForwarderOutConn.Close()
edgeConn.Close()
tcpOverWSConn.Close()
_ = wsForwarderOutConn.Close()
_ = edgeConn.Close()
_ = tcpOverWSConn.Close()
require.NoError(t, errGroup.Wait())
}
@@ -205,7 +206,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
go func() {
time.Sleep(time.Millisecond * 10)
// Simulate losing connection to origin
originConn.Close()
_ = originConn.Close()
}()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
@@ -221,11 +222,13 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
for i := 0; i < 50; i++ {
eyeballConn, edgeConn := net.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
assert.NoError(t, err)
require.NoError(t, err)
defer func() { _ = req.Body.Close() }()
resp, err := client.Transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer func() { _ = resp.Body.Close() }()
errGroup.Go(func() error {
for {
@@ -261,60 +264,18 @@ func echoWSEyeball(t *testing.T, conn net.Conn) {
assert.NoError(t, conn.Close())
}()
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
return
}
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
readMsg, err := wsutil.ReadServerBinary(conn)
if !assert.NoError(t, err) {
return
}
require.NoError(t, err)
assert.Equal(t, testResponse, readMsg)
}
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
var upgrader = gorillaWS.Upgrader{
ReadBufferSize: 10,
WriteBufferSize: 10,
}
ws := func(w http.ResponseWriter, r *http.Request) {
header := make(http.Header)
for k, vs := range r.Header {
if k == "Test-Cloudflared-Echo" {
header[k] = vs
}
}
conn, err := upgrader.Upgrade(w, r, header)
require.NoError(t, err)
defer conn.Close()
sawMessage := false
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
if expectMessages && !sawMessage {
t.Errorf("unexpected error: %v", err)
}
return
}
assert.Equal(t, testMessage, p)
sawMessage = true
if err := conn.WriteMessage(messageType, testResponse); err != nil {
return
}
}
}
// NewTLSServer starts the server in another thread
return httptest.NewTLSServer(http.HandlerFunc(ws))
}
func echoTCPOrigin(t *testing.T, conn net.Conn) {
readBuffer := make([]byte, len(testMessage))
_, err := conn.Read(readBuffer)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, testMessage, readBuffer)
+427
View File
@@ -0,0 +1,427 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../quic/quic_connection.go
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_quic_connection.go -source=../quic/quic_connection.go
//
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
net "net"
reflect "reflect"
quic "github.com/quic-go/quic-go"
gomock "go.uber.org/mock/gomock"
)
// MockQUICConnection is a mock of QUICConnection interface.
type MockQUICConnection struct {
ctrl *gomock.Controller
recorder *MockQUICConnectionMockRecorder
isgomock struct{}
}
// MockQUICConnectionMockRecorder is the mock recorder for MockQUICConnection.
type MockQUICConnectionMockRecorder struct {
mock *MockQUICConnection
}
// NewMockQUICConnection creates a new mock instance.
func NewMockQUICConnection(ctrl *gomock.Controller) *MockQUICConnection {
mock := &MockQUICConnection{ctrl: ctrl}
mock.recorder = &MockQUICConnectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockQUICConnection) EXPECT() *MockQUICConnectionMockRecorder {
return m.recorder
}
// AcceptStream mocks base method.
func (m *MockQUICConnection) AcceptStream(ctx context.Context) (*quic.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptStream", ctx)
ret0, _ := ret[0].(*quic.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcceptStream indicates an expected call of AcceptStream.
func (mr *MockQUICConnectionMockRecorder) AcceptStream(ctx any) *MockQUICConnectionAcceptStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQUICConnection)(nil).AcceptStream), ctx)
return &MockQUICConnectionAcceptStreamCall{Call: call}
}
// MockQUICConnectionAcceptStreamCall wrap *gomock.Call
type MockQUICConnectionAcceptStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionAcceptStreamCall) Return(arg0 *quic.Stream, arg1 error) *MockQUICConnectionAcceptStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionAcceptStreamCall) Do(f func(context.Context) (*quic.Stream, error)) *MockQUICConnectionAcceptStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionAcceptStreamCall) DoAndReturn(f func(context.Context) (*quic.Stream, error)) *MockQUICConnectionAcceptStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// CloseWithError mocks base method.
func (m *MockQUICConnection) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseWithError", code, reason)
ret0, _ := ret[0].(error)
return ret0
}
// CloseWithError indicates an expected call of CloseWithError.
func (mr *MockQUICConnectionMockRecorder) CloseWithError(code, reason any) *MockQUICConnectionCloseWithErrorCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQUICConnection)(nil).CloseWithError), code, reason)
return &MockQUICConnectionCloseWithErrorCall{Call: call}
}
// MockQUICConnectionCloseWithErrorCall wrap *gomock.Call
type MockQUICConnectionCloseWithErrorCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionCloseWithErrorCall) Return(arg0 error) *MockQUICConnectionCloseWithErrorCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionCloseWithErrorCall) Do(f func(quic.ApplicationErrorCode, string) error) *MockQUICConnectionCloseWithErrorCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionCloseWithErrorCall) DoAndReturn(f func(quic.ApplicationErrorCode, string) error) *MockQUICConnectionCloseWithErrorCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// ConnectionState mocks base method.
func (m *MockQUICConnection) ConnectionState() quic.ConnectionState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(quic.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState.
func (mr *MockQUICConnectionMockRecorder) ConnectionState() *MockQUICConnectionConnectionStateCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQUICConnection)(nil).ConnectionState))
return &MockQUICConnectionConnectionStateCall{Call: call}
}
// MockQUICConnectionConnectionStateCall wrap *gomock.Call
type MockQUICConnectionConnectionStateCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionConnectionStateCall) Return(arg0 quic.ConnectionState) *MockQUICConnectionConnectionStateCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionConnectionStateCall) Do(f func() quic.ConnectionState) *MockQUICConnectionConnectionStateCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionConnectionStateCall) DoAndReturn(f func() quic.ConnectionState) *MockQUICConnectionConnectionStateCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Context mocks base method.
func (m *MockQUICConnection) Context() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context.
func (mr *MockQUICConnectionMockRecorder) Context() *MockQUICConnectionContextCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQUICConnection)(nil).Context))
return &MockQUICConnectionContextCall{Call: call}
}
// MockQUICConnectionContextCall wrap *gomock.Call
type MockQUICConnectionContextCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionContextCall) Return(arg0 context.Context) *MockQUICConnectionContextCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionContextCall) Do(f func() context.Context) *MockQUICConnectionContextCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionContextCall) DoAndReturn(f func() context.Context) *MockQUICConnectionContextCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// LocalAddr mocks base method.
func (m *MockQUICConnection) LocalAddr() net.Addr {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LocalAddr")
ret0, _ := ret[0].(net.Addr)
return ret0
}
// LocalAddr indicates an expected call of LocalAddr.
func (mr *MockQUICConnectionMockRecorder) LocalAddr() *MockQUICConnectionLocalAddrCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQUICConnection)(nil).LocalAddr))
return &MockQUICConnectionLocalAddrCall{Call: call}
}
// MockQUICConnectionLocalAddrCall wrap *gomock.Call
type MockQUICConnectionLocalAddrCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionLocalAddrCall) Return(arg0 net.Addr) *MockQUICConnectionLocalAddrCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionLocalAddrCall) Do(f func() net.Addr) *MockQUICConnectionLocalAddrCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConnectionLocalAddrCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenStream mocks base method.
func (m *MockQUICConnection) OpenStream() (*quic.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStream")
ret0, _ := ret[0].(*quic.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStream indicates an expected call of OpenStream.
func (mr *MockQUICConnectionMockRecorder) OpenStream() *MockQUICConnectionOpenStreamCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQUICConnection)(nil).OpenStream))
return &MockQUICConnectionOpenStreamCall{Call: call}
}
// MockQUICConnectionOpenStreamCall wrap *gomock.Call
type MockQUICConnectionOpenStreamCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionOpenStreamCall) Return(arg0 *quic.Stream, arg1 error) *MockQUICConnectionOpenStreamCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionOpenStreamCall) Do(f func() (*quic.Stream, error)) *MockQUICConnectionOpenStreamCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionOpenStreamCall) DoAndReturn(f func() (*quic.Stream, error)) *MockQUICConnectionOpenStreamCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// OpenStreamSync mocks base method.
func (m *MockQUICConnection) OpenStreamSync(ctx context.Context) (*quic.Stream, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenStreamSync", ctx)
ret0, _ := ret[0].(*quic.Stream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenStreamSync indicates an expected call of OpenStreamSync.
func (mr *MockQUICConnectionMockRecorder) OpenStreamSync(ctx any) *MockQUICConnectionOpenStreamSyncCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQUICConnection)(nil).OpenStreamSync), ctx)
return &MockQUICConnectionOpenStreamSyncCall{Call: call}
}
// MockQUICConnectionOpenStreamSyncCall wrap *gomock.Call
type MockQUICConnectionOpenStreamSyncCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionOpenStreamSyncCall) Return(arg0 *quic.Stream, arg1 error) *MockQUICConnectionOpenStreamSyncCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionOpenStreamSyncCall) Do(f func(context.Context) (*quic.Stream, error)) *MockQUICConnectionOpenStreamSyncCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionOpenStreamSyncCall) DoAndReturn(f func(context.Context) (*quic.Stream, error)) *MockQUICConnectionOpenStreamSyncCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// ReceiveDatagram mocks base method.
func (m *MockQUICConnection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReceiveDatagram", ctx)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReceiveDatagram indicates an expected call of ReceiveDatagram.
func (mr *MockQUICConnectionMockRecorder) ReceiveDatagram(ctx any) *MockQUICConnectionReceiveDatagramCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockQUICConnection)(nil).ReceiveDatagram), ctx)
return &MockQUICConnectionReceiveDatagramCall{Call: call}
}
// MockQUICConnectionReceiveDatagramCall wrap *gomock.Call
type MockQUICConnectionReceiveDatagramCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionReceiveDatagramCall) Return(arg0 []byte, arg1 error) *MockQUICConnectionReceiveDatagramCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionReceiveDatagramCall) Do(f func(context.Context) ([]byte, error)) *MockQUICConnectionReceiveDatagramCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionReceiveDatagramCall) DoAndReturn(f func(context.Context) ([]byte, error)) *MockQUICConnectionReceiveDatagramCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// RemoteAddr mocks base method.
func (m *MockQUICConnection) RemoteAddr() net.Addr {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoteAddr")
ret0, _ := ret[0].(net.Addr)
return ret0
}
// RemoteAddr indicates an expected call of RemoteAddr.
func (mr *MockQUICConnectionMockRecorder) RemoteAddr() *MockQUICConnectionRemoteAddrCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQUICConnection)(nil).RemoteAddr))
return &MockQUICConnectionRemoteAddrCall{Call: call}
}
// MockQUICConnectionRemoteAddrCall wrap *gomock.Call
type MockQUICConnectionRemoteAddrCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionRemoteAddrCall) Return(arg0 net.Addr) *MockQUICConnectionRemoteAddrCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionRemoteAddrCall) Do(f func() net.Addr) *MockQUICConnectionRemoteAddrCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConnectionRemoteAddrCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// SendDatagram mocks base method.
func (m *MockQUICConnection) SendDatagram(payload []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendDatagram", payload)
ret0, _ := ret[0].(error)
return ret0
}
// SendDatagram indicates an expected call of SendDatagram.
func (mr *MockQUICConnectionMockRecorder) SendDatagram(payload any) *MockQUICConnectionSendDatagramCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockQUICConnection)(nil).SendDatagram), payload)
return &MockQUICConnectionSendDatagramCall{Call: call}
}
// MockQUICConnectionSendDatagramCall wrap *gomock.Call
type MockQUICConnectionSendDatagramCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnectionSendDatagramCall) Return(arg0 error) *MockQUICConnectionSendDatagramCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnectionSendDatagramCall) Do(f func([]byte) error) *MockQUICConnectionSendDatagramCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnectionSendDatagramCall) DoAndReturn(f func([]byte) error) *MockQUICConnectionSendDatagramCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
+7 -6
View File
@@ -17,12 +17,13 @@ import (
reflect "reflect"
time "time"
quic "github.com/quic-go/quic-go"
quic0 "github.com/quic-go/quic-go"
zerolog "github.com/rs/zerolog"
gomock "go.uber.org/mock/gomock"
dialopts "github.com/cloudflare/cloudflared/connection/dialopts"
allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions"
quic "github.com/cloudflare/cloudflared/quic"
)
// MockDNSResolver is a mock of DNSResolver interface.
@@ -176,10 +177,10 @@ func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder {
}
// DialQuic mocks base method.
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) {
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic0.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.QUICConnection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
ret0, _ := ret[0].(quic.Connection)
ret0, _ := ret[0].(quic.QUICConnection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -197,19 +198,19 @@ type MockQUICDialerDialQuicCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.QUICConnection, arg1 error) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
+2
View File
@@ -5,3 +5,5 @@ package mocks
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_resolvers.go -source=../prechecks/resolvers.go"
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_quic_connection.go -source=../quic/quic_connection.go"
+129 -88
View File
@@ -30,16 +30,17 @@ type RunDialers struct {
ManagementDialer ManagementDialer
}
// TransportResults holds the per-region results for each transport probe type.
// Each slice has one entry per DNS-resolved region, in the same order as dnsResults.
// TransportResults holds the per-target results for each transport probe type.
// Each slice has one entry per resolved target group, in the same order as the
// target labels slice.
type TransportResults struct {
QUIC []CheckResult // one per region
HTTP2 []CheckResult // one per region
ManagementAPI CheckResult // single target, no regions
QUIC []CheckResult // one per target group
HTTP2 []CheckResult // one per target group
ManagementAPI CheckResult // single target, no groups
}
// Collect returns all results as a slice in a consistent order for reporting:
// all QUIC rows first (one per region), then all HTTP2 rows, then Management API.
// all QUIC rows first (one per target), then all HTTP2 rows, then Management API.
func (tr TransportResults) Collect() []CheckResult {
results := make([]CheckResult, 0, len(tr.QUIC)+len(tr.HTTP2)+1)
results = append(results, tr.QUIC...)
@@ -50,8 +51,11 @@ func (tr TransportResults) Collect() []CheckResult {
// Run executes the following connectivity pre-checks:
//
// 1. DNS resolution (sequential transport probes depend on its output).
// 2. QUIC, HTTP/2, and Management API probes run concurrently.
// 1. Edge address resolution — either DNS-based SRV discovery (normal path)
// or direct resolution of --edge addresses (static path). The static path
// skips DNS probe rows entirely since there are no SRV records to validate.
// 2. QUIC, HTTP/2, and Management API probes run concurrently against the
// resolved addresses.
//
// Each failed probe is retried up to maxRetries times with exponential backoff.
// The suite is bounded by cfg.Timeout (defaultTimeout if zero).
@@ -64,19 +68,39 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru
ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
defer cancel()
// Build TLS configs once per protocol
// Build TLS configs once per protocol.
quicTLSConfig, quicTLSErr := probeTLSConfig(caCert, connection.QUIC)
http2TLSConfig, http2TLSErr := probeTLSConfig(caCert, connection.HTTP2)
// 1) DNS must complete before transport probes know which addresses to dial.
addrGroups, dnsResults := runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
// 1) Resolve edge addresses. Each ResolvedTarget bundles its addr group
// with the DNS CheckResult that labels it, keeping the two in sync.
var resolvedTargets []ResolvedTarget
if len(cfg.EdgeAddrs) > 0 {
// Static path: explicit --edge addresses, one ResolvedTarget per addr.
resolvedTargets = resolveStaticEdge(cfg.EdgeAddrs, log)
} else {
// Normal path: SRV-based discovery; DNS rows carry Pass or Fail status.
resolvedTargets = runDNSProbe(ctx, runDialers.DNSResolver, cfg.Region)
}
dnsOK := !slices.ContainsFunc(dnsResults, func(r CheckResult) bool {
return r.ProbeStatus != Pass
// Extract parallel slices for the transport probe layer.
// nolint:prealloc // False positive. The linter is confused by the append used when producing Report.Results
dnsResults := make([]CheckResult, len(resolvedTargets))
perGroupAddrs := make([][]*allregions.EdgeAddr, len(resolvedTargets))
targetLabels := make([]string, len(resolvedTargets))
for i, rt := range resolvedTargets {
dnsResults[i] = rt.DNSResult
perGroupAddrs[i] = rt.Addrs
targetLabels[i] = rt.DNSResult.Target
}
// dnsOK is true when at least one target has addresses to probe.
dnsOK := slices.ContainsFunc(resolvedTargets, func(r ResolvedTarget) bool {
return len(r.Addrs) > 0
})
// 2) Run probes concurrently. Each probe type gets its own buffered channel —
// one send, one receive, no routing or name-parsing required.
// 2) Run transport probes concurrently. Each probe type gets its own
// buffered channel — one send, one receive, no routing required.
var results TransportResults
mgmtCh := make(chan CheckResult)
@@ -85,12 +109,12 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru
}()
if !dnsOK {
// DNS failed: emit one skip row per region so the table stays consistent.
results.QUIC = skipResultsForRegions(dnsResults, ProbeTypeQUIC, componentUDPConnectivity)
results.HTTP2 = skipResultsForRegions(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity)
// No addresses available: emit one skip row per target so the table
// stays consistent with the DNS rows above.
results.QUIC = skipResultsForTargets(dnsResults, ProbeTypeQUIC, componentUDPConnectivity)
results.HTTP2 = skipResultsForTargets(dnsResults, ProbeTypeHTTP2, componentTCPConnectivity)
} else {
perRegionAddrs := addrsByRegion(addrGroups, cfg.IPVersion)
regionTargets := dnsTargets(dnsResults)
filteredAddrs := addrsByGroup(perGroupAddrs, cfg.IPVersion)
quicCh := make(chan []CheckResult, 1)
http2Ch := make(chan []CheckResult, 1)
@@ -99,11 +123,11 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru
if quicTLSErr != nil {
log.Warn().Err(quicTLSErr).Msg("Failed to build QUIC probe TLS config")
quicCh <- tlsConfigErrResults(ProbeTypeQUIC, componentUDPConnectivity,
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked)
targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, quicTLSErr), actionQUICBlocked)
return
}
quicCh <- probeAllRegions(ctx, ProbeTypeQUIC, componentUDPConnectivity,
perRegionAddrs, regionTargets,
quicCh <- probeAllTargets(ctx, ProbeTypeQUIC, componentUDPConnectivity,
filteredAddrs, targetLabels,
func(addr *allregions.EdgeAddr) CheckResult {
return probeQUIC(ctx, quicTLSConfig, runDialers.QUICDialer, addr, log)
})
@@ -113,11 +137,11 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru
if http2TLSErr != nil {
log.Warn().Err(http2TLSErr).Msg("Failed to build HTTP/2 probe TLS config")
http2Ch <- tlsConfigErrResults(ProbeTypeHTTP2, componentTCPConnectivity,
regionTargets, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked)
targetLabels, fmt.Sprintf("%s: %v", detailsTLSConfigFailed, http2TLSErr), actionHTTP2Blocked)
return
}
http2Ch <- probeAllRegions(ctx, ProbeTypeHTTP2, componentTCPConnectivity,
perRegionAddrs, regionTargets,
http2Ch <- probeAllTargets(ctx, ProbeTypeHTTP2, componentTCPConnectivity,
filteredAddrs, targetLabels,
func(addr *allregions.EdgeAddr) CheckResult {
return probeHTTP2(ctx, http2TLSConfig, runDialers.TCPDialer, addr)
})
@@ -132,15 +156,15 @@ func Run(ctx context.Context, caCert string, cfg Config, log *zerolog.Logger, ru
return Report{
RunID: runID,
Results: append(dnsResults, results.Collect()...),
SuggestedProtocol: suggestProtocol(results.QUIC, results.HTTP2),
SuggestedProtocol: suggestProtocol(results.QUIC, results.HTTP2, cfg.ProtocolOverride),
}
}
// tlsConfigErrResults returns one Fail CheckResult per region target, used when
// tlsConfigErrResults returns one Fail CheckResult per target, used when
// TLS config construction fails before any dial is attempted.
func tlsConfigErrResults(probeType ProbeType, component string, regionTargets []string, details, action string) []CheckResult {
results := make([]CheckResult, len(regionTargets))
for i, target := range regionTargets {
func tlsConfigErrResults(probeType ProbeType, component string, targets []string, details, action string) []CheckResult {
results := make([]CheckResult, len(targets))
for i, target := range targets {
results[i] = CheckResult{
Type: probeType,
Component: component,
@@ -153,47 +177,32 @@ func tlsConfigErrResults(probeType ProbeType, component string, regionTargets []
return results
}
func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) ([][]*allregions.EdgeAddr, []CheckResult) {
var addrGroups [][]*allregions.EdgeAddr
var dnsResults []CheckResult
withRetry(ctx, maxRetries, func() bool {
addrGroups, dnsResults = probeDNS(resolver, region)
for _, r := range dnsResults {
if r.ProbeStatus == Fail {
return false
}
}
return len(dnsResults) > 0
})
return addrGroups, dnsResults
}
// probeAllRegions probes each region sequentially and returns one CheckResult
// per region. Within each region, all available addresses (V4 and/or V6) are
// tried and the best result is kept.
func probeAllRegions(
// probeAllTargets probes each target group sequentially and returns one
// CheckResult per group. Within each group, all available addresses (V4 and/or
// V6) are tried and the best result is kept.
func probeAllTargets(
ctx context.Context,
probeType ProbeType,
component string,
perRegionAddrs [][]*allregions.EdgeAddr,
regionTargets []string,
perGroupAddrs [][]*allregions.EdgeAddr,
targets []string,
probeFn func(*allregions.EdgeAddr) CheckResult,
) []CheckResult {
results := make([]CheckResult, len(perRegionAddrs))
for i, addrs := range perRegionAddrs {
results[i] = probeRegion(ctx, probeType, component, regionTargets[i], addrs, probeFn)
results := make([]CheckResult, len(perGroupAddrs))
for i, addrs := range perGroupAddrs {
results[i] = probeTarget(ctx, probeType, component, targets[i], addrs, probeFn)
}
return results
}
// probeRegion probes all addresses for a single region (typically one V4 and/or
// one V6) and returns the best result. Any address passing means the region is
// reachable, so Pass beats Fail within a region.
func probeRegion(
// probeTarget probes all addresses for a single target group (typically one V4
// and/or one V6) and returns the best result. Any address passing means the
// target is reachable, so Pass beats Fail within a group.
func probeTarget(
ctx context.Context,
probeType ProbeType,
component string,
regionTarget string,
target string,
addrs []*allregions.EdgeAddr,
probeFn func(*allregions.EdgeAddr) CheckResult,
) CheckResult {
@@ -201,7 +210,7 @@ func probeRegion(
return CheckResult{
Type: probeType,
Component: component,
Target: regionTarget,
Target: target,
ProbeStatus: Skip,
Details: "No suitable address found for configured IP version",
}
@@ -213,7 +222,7 @@ func probeRegion(
best = r
}
}
best.Target = regionTarget
best.Target = target
return best
}
@@ -238,11 +247,11 @@ func probeWithRetry(ctx context.Context, addr *allregions.EdgeAddr, probeFn func
return r
}
// addrsByRegion returns the addresses to probe for each DNS-resolved region,
// preserving the per-region grouping. Each inner slice contains at most one V4
// addrsByGroup returns the addresses to probe for each resolved target group,
// preserving the per-group structure. Each inner slice contains at most one V4
// and one V6 address (subject to ipVersion).
func addrsByRegion(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr {
perRegion := make([][]*allregions.EdgeAddr, 0, len(addrGroups))
func addrsByGroup(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.ConfigIPVersion) [][]*allregions.EdgeAddr {
perGroup := make([][]*allregions.EdgeAddr, 0, len(addrGroups))
for _, group := range addrGroups {
v4, v6 := addrsByFamily(group, ipVersion)
var addrs []*allregions.EdgeAddr
@@ -252,27 +261,17 @@ func addrsByRegion(addrGroups [][]*allregions.EdgeAddr, ipVersion allregions.Con
if v6 != nil {
addrs = append(addrs, v6)
}
perRegion = append(perRegion, addrs)
perGroup = append(perGroup, addrs)
}
return perRegion
return perGroup
}
// dnsTargets extracts the Target hostname from each DNS CheckResult so that
// transport probe rows reuse the same region hostnames.
func dnsTargets(dnsResults []CheckResult) []string {
targets := make([]string, len(dnsResults))
for i, r := range dnsResults {
targets[i] = r.Target
}
return targets
}
// skipResultsForRegions returns one skip CheckResult per DNS region, using each
// region's hostname as the Target so the output table row aligns with its DNS row.
func skipResultsForRegions(dnsResults []CheckResult, probeType ProbeType, component string) []CheckResult {
results := make([]CheckResult, len(dnsResults))
for i, dns := range dnsResults {
results[i] = skipResult(probeType, component, dns.Target)
// skipResultsForTargets returns one skip CheckResult per entry in results,
// using each entry's Target label so the transport row aligns with its DNS row.
func skipResultsForTargets(targets []CheckResult, probeType ProbeType, component string) []CheckResult {
results := make([]CheckResult, len(targets))
for i, t := range targets {
results[i] = skipResult(probeType, component, t.Target, detailsDNSPrerequisiteFailed)
}
return results
}
@@ -304,10 +303,52 @@ func severity(s Status) int {
}
}
// suggestProtocol recommends QUIC when all QUIC region probes passed, HTTP/2
// when all HTTP/2 probes passed, and nil when neither transport works.
// parseProtocolOverride converts the raw --protocol flag string into a
// *connection.Protocol. It returns nil when the string is empty, "auto", or
// unrecognised, so the probe heuristic is used in those cases. "h2mux" is
// treated as HTTP/2 because both map to the same transport.
func parseProtocolOverride(flag string) *connection.Protocol {
switch flag {
case connection.QUIC.String():
p := connection.QUIC
return &p
case connection.HTTP2.String(), "h2mux":
p := connection.HTTP2
return &p
default:
// "auto", empty, or unknown — no override; let the heuristic decide.
return nil
}
}
// suggestProtocol determines the protocol to report in the pre-check summary.
//
// When the caller has explicitly overridden the protocol via --protocol, that
// choice is honoured when its transport probes produced evidence and did not
// fail.
//
// When there is no override (auto-selection), precedence is QUIC, HTTP/2,
// and nil. A protocol is only suggested if all probes pass.
//
// Any region failing means the transport is treated as failed (worst wins).
func suggestProtocol(quicResults, http2Results []CheckResult) *connection.Protocol {
func suggestProtocol(quicResults, http2Results []CheckResult, overrideFlag string) *connection.Protocol {
if override := parseProtocolOverride(overrideFlag); override != nil {
switch *override {
case connection.QUIC:
// Only report QUIC as the suggested protocol if its probes did not
// all fail — if they did, fall through to the heuristic so the
// summary can report a usable fallback or nil.
if len(quicResults) > 0 && worstStatus(quicResults) != Fail {
return new(connection.QUIC)
}
case connection.HTTP2:
// Same logic for an explicit HTTP/2 override.
if len(http2Results) > 0 && worstStatus(http2Results) != Fail {
return new(connection.HTTP2)
}
}
}
if len(quicResults) > 0 && worstStatus(quicResults) == Pass {
quic := connection.QUIC
return &quic
@@ -320,7 +361,7 @@ func suggestProtocol(quicResults, http2Results []CheckResult) *connection.Protoc
}
// withRetry calls fn up to 1+maxAttempts times, stopping as soon as fn returns
// true. Between attempts it sleeps with exponential backoff bounded by
// true. Between attempts, it sleeps with exponential backoff bounded by
// maxRetryDelay, and stops early if ctx is done.
func withRetry(ctx context.Context, maxAttempts int, fn func() bool) {
b := backoff.NewWithoutJitter(maxRetryDelay, retryBaseDelay)
+294 -39
View File
@@ -8,7 +8,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -67,15 +66,6 @@ type nopConn struct{ net.Conn }
func (nopConn) Close() error { return nil }
// fakeQUICConn satisfies quic.Connection for tests. Only CloseWithError is
// implemented; the pre-check never opens streams so the rest of the interface
// is unused via the embedded nil.
type fakeQUICConn struct {
quic.Connection
}
func (*fakeQUICConn) CloseWithError(_ quic.ApplicationErrorCode, _ string) error { return nil }
// requireStatuses asserts the probe statuses in report.Results match
// expected (in order), failing immediately on length mismatch.
func requireStatuses(t *testing.T, report Report, expected ...Status) {
@@ -94,6 +84,14 @@ func nopLogger() *zerolog.Logger {
return &l
}
// newFakeQUICConn creates a mock QUIC connection with CloseWithError
// expectation pre-configured so gomock does not fail at runtime.
func newFakeQUICConn(ctrl *gomock.Controller) *mocks.MockQUICConnection {
conn := mocks.NewMockQUICConnection(ctrl)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
return conn
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -108,13 +106,14 @@ func TestRun_AllPass(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
// twoRegionAddrs has 2 regions × 1 V4 address each = 2 dials per transport.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(2)
Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
@@ -166,6 +165,7 @@ func TestRun_HTTP2Blocked(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
fakeQUICConn := newFakeQUICConn(ctrl)
dns := mocks.NewMockDNSResolver(ctrl)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
@@ -175,7 +175,7 @@ func TestRun_HTTP2Blocked(t *testing.T) {
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, errors.New("connection refused")).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).AnyTimes()
Return(fakeQUICConn, nil).AnyTimes()
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
@@ -229,6 +229,7 @@ func TestRun_PartialRegionQUICFail(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
// Two regions: 1.2.3.4 (region1) and 5.6.7.8 (region2).
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
@@ -241,7 +242,7 @@ func TestRun_PartialRegionQUICFail(t *testing.T) {
region1Addr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 7844}
region2Addr := &net.UDPAddr{IP: net.ParseIP("5.6.7.8"), Port: 7844}
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), region1Addr.AddrPort(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).AnyTimes()
Return(fakeQUICConn, nil).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), region2Addr.AddrPort(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, errors.New("connection refused")).AnyTimes()
@@ -308,13 +309,14 @@ func TestRun_ManagementAPIFail(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
// twoRegionAddrs has 2 regions × 1 V4 address each; each succeeds on first try.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(2)
Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, errors.New("connection refused")).AnyTimes()
@@ -339,13 +341,14 @@ func TestRun_RegionFlagForwardedToDNS(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
// The region string must be forwarded verbatim to the DNS resolver.
dns.EXPECT().Resolve("us").Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(2)
Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
@@ -373,6 +376,7 @@ func TestRun_QUICUsesProbeConnIndex(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
@@ -383,7 +387,7 @@ func TestRun_QUICUsesProbeConnIndex(t *testing.T) {
gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Eq(uint8(math.MaxUint8)),
gomock.Any(), gomock.Any(),
).Return(&fakeQUICConn{}, nil).Times(2)
).Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
@@ -401,13 +405,14 @@ func TestRun_BothFamiliesProbed(t *testing.T) {
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil)
// 2 regions × 2 families = 4 dial calls each for QUIC and HTTP/2.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(4)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(4)
Return(fakeQUICConn, nil).Times(4)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
@@ -420,9 +425,247 @@ func TestRun_BothFamiliesProbed(t *testing.T) {
assert.Equal(t, connection.QUIC, *report.SuggestedProtocol)
}
// TestRun_IPv4OnlySkipsV6 verifies that when IPv4Only is configured only V4
// addresses are probed (2 regions × 1 V4 = 2 dials per transport).
func TestRun_IPv4OnlySkipsV6(t *testing.T) {
// TestRun_IPVersionRestriction verifies that when a single IP family is
// configured, only that family is probed (2 regions × 1 addr = 2 dials per
// transport) and the excluded family is never dialled.
func TestRun_IPVersionRestriction(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ipVersion allregions.ConfigIPVersion
}{
{"IPv4Only skips V6", allregions.IPv4Only},
{"IPv6Only skips V4", allregions.IPv6Only},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
dns := mocks.NewMockDNSResolver(ctrl)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil)
// 2 regions × 1 addr per restricted family = 2 dials each.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: tt.ipVersion},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
})
}
}
// TestRun_EdgeAddrs_SingleAddr verifies that a single --edge addr bypasses DNS
// probing. The report contains one DNS Skip row, transport rows labeled with
// the raw addr string, and the Management API row.
func TestRun_EdgeAddrs_SingleAddr(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
// DNS resolver must NOT be called when EdgeAddrs is set.
dns := mocks.NewMockDNSResolver(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Times(0)
// One addr resolves to one group → one dial per transport.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(1)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(fakeQUICConn, nil).Times(1)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
cfg := Config{
EdgeAddrs: []string{"127.0.0.1:7844"},
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 1 DNS Skip + 1 QUIC + 1 HTTP2 + 1 API = 4 results.
requireStatuses(t, report, Pass, Pass, Pass, Pass)
assert.Equal(t, ProbeTypeDNS, report.Results[0].Type, "first row must be DNS skip")
assert.Equal(t, "127.0.0.1:7844", report.Results[1].Target, "QUIC target must be the raw --edge addr")
assert.Equal(t, "127.0.0.1:7844", report.Results[2].Target, "HTTP2 target must be the raw --edge addr")
require.NotNil(t, report.SuggestedProtocol)
assert.Equal(t, connection.QUIC, *report.SuggestedProtocol)
}
// TestRun_EdgeAddrs_MultipleAddrs verifies that multiple --edge addrs produce
// one transport row per addr, each labeled with its original addr string.
func TestRun_EdgeAddrs_MultipleAddrs(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns := mocks.NewMockDNSResolver(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Times(0)
// Two addrs → two groups → two dials per transport.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(fakeQUICConn, nil).Times(2)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
cfg := Config{
EdgeAddrs: []string{"127.0.0.1:7844", "127.0.0.2:7844"},
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 2 DNS Pass (one per addr) + 2 QUIC + 2 HTTP2 + 1 API = 7 results.
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
assert.Equal(t, ProbeTypeDNS, report.Results[0].Type, "first row must be DNS skip addr1")
assert.Equal(t, "127.0.0.1:7844", report.Results[0].Target, "DNS skip addr1 label")
assert.Equal(t, ProbeTypeDNS, report.Results[1].Type, "second row must be DNS skip addr2")
assert.Equal(t, "127.0.0.2:7844", report.Results[1].Target, "DNS skip addr2 label")
assert.Equal(t, "127.0.0.1:7844", report.Results[2].Target, "QUIC addr1")
assert.Equal(t, "127.0.0.2:7844", report.Results[3].Target, "QUIC addr2")
assert.Equal(t, "127.0.0.1:7844", report.Results[4].Target, "HTTP2 addr1")
assert.Equal(t, "127.0.0.2:7844", report.Results[5].Target, "HTTP2 addr2")
}
// TestRun_EdgeAddrs_UnresolvableAddr verifies that when all --edge addrs fail
// to resolve, the DNS resolver is not called and transport rows are skipped,
// mirroring the DNS skip row.
func TestRun_EdgeAddrs_UnresolvableAddr(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
dns := mocks.NewMockDNSResolver(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Times(0)
// Unresolvable addr → no groups → no transport dials.
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
cfg := Config{
EdgeAddrs: []string{"not-a-valid-addr"},
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// 1 DNS Fail + 1 QUIC Skip + 1 HTTP2 Skip + 1 API = 4 results.
requireStatuses(t, report, Fail, Skip, Skip, Pass)
assert.Equal(t, ProbeTypeDNS, report.Results[0].Type)
assert.Equal(t, "not-a-valid-addr", report.Results[0].Target)
assert.Equal(t, ProbeTypeQUIC, report.Results[1].Type)
assert.Equal(t, ProbeTypeHTTP2, report.Results[2].Type)
assert.Nil(t, report.SuggestedProtocol)
assert.True(t, report.hasHardFail())
}
// ---------------------------------------------------------------------------
// Protocol override tests
// ---------------------------------------------------------------------------
// TestRun_ProtocolOverride_HTTP2_BothPass verifies that when --protocol http2
// is set and both transports are reachable, the summary reports HTTP/2 (not
// QUIC, which would otherwise win the heuristic).
func TestRun_ProtocolOverride_HTTP2_BothPass(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
dns := mocks.NewMockDNSResolver(ctrl)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(fakeQUICConn, nil).AnyTimes()
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
cfg := Config{
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
ProtocolOverride: "http2",
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
// Both transports pass, but the override must win — HTTP/2 is reported.
require.NotNil(t, report.SuggestedProtocol)
assert.Equal(t, connection.HTTP2, *report.SuggestedProtocol,
"override http2 should be reported even though QUIC probes also passed")
assert.False(t, report.hasHardFail())
}
// TestRun_ProtocolOverride_QUIC_BothPass verifies that when --protocol quic is
// set and both transports are reachable, the summary reports QUIC (same as the
// heuristic would choose, but driven by the override).
func TestRun_ProtocolOverride_QUIC_BothPass(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
dns := mocks.NewMockDNSResolver(ctrl)
tcp := mocks.NewMockTCPDialer(ctrl)
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
fakeQUICConn := newFakeQUICConn(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(fakeQUICConn, nil).AnyTimes()
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
cfg := Config{
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
ProtocolOverride: "quic",
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
require.NotNil(t, report.SuggestedProtocol)
assert.Equal(t, connection.QUIC, *report.SuggestedProtocol)
assert.False(t, report.hasHardFail())
}
// TestRun_ProtocolOverride_HTTP2_QUICBlocked verifies that when --protocol http2
// is set and QUIC is blocked, we still report HTTP/2 (not a fallback to the
// heuristic, since the overridden transport is healthy).
func TestRun_ProtocolOverride_HTTP2_QUICBlocked(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
@@ -431,25 +674,31 @@ func TestRun_IPv4OnlySkipsV6(t *testing.T) {
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil)
// IPv4Only: only V4 addresses are probed → 2 regions × 1 V4 = 2 calls each.
// V6 addresses must never be dialed.
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
Return(nopConn{}, nil).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(2)
Return(nil, errors.New("blocked")).AnyTimes()
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv4Only},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
cfg := Config{
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
ProtocolOverride: "http2",
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
require.NotNil(t, report.SuggestedProtocol)
assert.Equal(t, connection.HTTP2, *report.SuggestedProtocol)
assert.False(t, report.hasHardFail())
}
// TestRun_IPv6OnlySkipsV4 verifies that when IPv6Only is configured only V6
// addresses are probed (2 regions × 1 V6 = 2 dials per transport).
func TestRun_IPv6OnlySkipsV4(t *testing.T) {
// TestRun_ProtocolOverride_HTTP2_BothBlocked verifies that when --protocol http2
// is set but the HTTP/2 transport itself also fails (hard fail), the override
// falls through to the heuristic which returns nil — there is no usable protocol.
func TestRun_ProtocolOverride_HTTP2_BothBlocked(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
@@ -458,18 +707,24 @@ func TestRun_IPv6OnlySkipsV4(t *testing.T) {
quicD := mocks.NewMockQUICDialer(ctrl)
mgmt := mocks.NewMockManagementDialer(ctrl)
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrsBothFamilies(), nil)
// IPv6Only: only V6 addresses are probed → 2 regions × 1 V6 = 2 calls each.
// V4 addresses must never be dialled.
dns.EXPECT().Resolve(gomock.Any()).Return(twoRegionAddrs(), nil)
tcp.EXPECT().DialEdge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil).Times(2)
Return(nil, errors.New("blocked")).AnyTimes()
quicD.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(&fakeQUICConn{}, nil).Times(2)
Return(nil, errors.New("blocked")).AnyTimes()
mgmt.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nopConn{}, nil)
report := Run(t.Context(), emptyCert, Config{Timeout: 2 * time.Second, IPVersion: allregions.IPv6Only},
nopLogger(), RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
cfg := Config{
Timeout: 2 * time.Second,
IPVersion: allregions.Auto,
ProtocolOverride: "http2",
}
report := Run(t.Context(), emptyCert, cfg, nopLogger(),
RunDialers{DNSResolver: dns, TCPDialer: tcp, QUICDialer: quicD, ManagementDialer: mgmt})
requireStatuses(t, report, Pass, Pass, Pass, Pass, Pass, Pass, Pass)
// The overridden transport (HTTP/2) is blocked, so the override cannot be
// honoured and the hard-fail path reports no suggested protocol.
assert.Nil(t, report.SuggestedProtocol)
assert.True(t, report.hasHardFail())
}
+80 -51
View File
@@ -17,6 +17,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tlsconfig"
)
@@ -25,7 +26,7 @@ const (
// Action messages for each probe outcome.
actionDNSFail = "Ensure your DNS resolver can resolve '%s'. Run: dig A %s @1.1.1.1. If that fails, contact your network administrator."
actionQUICBlocked = "QUIC traffic failed to connect to port 7844."
actionQUICBlocked = "Allow outbound QUIC traffic on port 7844 or use HTTP2."
actionHTTP2Blocked = "Allow outbound TCP on port 7844."
actionAPIUnreachable = "cloudflared will still run, but automatic software updates are unavailable. " +
"Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates."
@@ -40,18 +41,19 @@ const (
targetPortQUIC = "Port 7844 (QUIC)"
targetPortHTTP2 = "Port 7844 (HTTP/2)"
targetAPI = "api.cloudflare.com:443"
noDNSTarget = "No DNS target (Using edge flag)"
// Details messages for CheckResult.
detailsNoAddressesReturned = "No addresses returned"
detailsResolvedSuccessfully = "Resolved successfully"
detailsHandshakeFailed = "Handshake failed"
detailsHandshakeSuccessful = "Handshake successful"
detailsBlockedOrUnreachable = "Blocked or unreachable"
detailsTLSHandshakeSuccessful = "TLS handshake successful"
detailsConnectionFailed = "Connection failed"
detailsTCPPortReachable = "TCP port reachable (TLS not validated)"
detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
detailsTLSConfigFailed = "TLS configuration failed"
dnsNoAddressesReturned = "No addresses returned"
dnsResolvedSuccessfully = "DNS Resolved successfully"
detailsQUICHandshakeFailed = "QUIC connection failed"
detailsQUICHandshakeSuccessful = "QUIC connection successful"
detailsHTTP2BlockedOrUnreachable = "HTTP/2 connection is blocked or unreachable"
detailsHTTP2HandshakeSuccessful = "HTTP/2 connection successful"
detailsAPIConnectionFailed = "API Connection failed"
detailsApiReachable = "API is reachable"
detailsDNSPrerequisiteFailed = "DNS prerequisite failed"
detailsTLSConfigFailed = "TLS configuration failed"
// Region hostname templates.
region1Global = "region1.v2.argotunnel.com"
@@ -72,20 +74,6 @@ func (r *EdgeDNSResolver) Resolve(region string) ([][]*allregions.EdgeAddr, erro
return allregions.EdgeDiscovery(r.Log, allregions.RegionalServiceName(region))
}
// StaticEdgeDNSResolver implements DNSResolver for the --edge flag path.
type StaticEdgeDNSResolver struct {
Addrs []string
Log *zerolog.Logger
}
func (r *StaticEdgeDNSResolver) Resolve(_ string) ([][]*allregions.EdgeAddr, error) {
resolved := allregions.ResolveAddrs(r.Addrs, r.Log)
if len(resolved) == 0 {
return nil, fmt.Errorf("failed to resolve any edge address")
}
return [][]*allregions.EdgeAddr{resolved}, nil
}
type EdgeTCPDialer struct{}
func (d *EdgeTCPDialer) DialEdge(
@@ -109,7 +97,7 @@ func (d *EdgeQUICDialer) DialQuic(
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error) {
) (cfdquic.QUICConnection, error) {
return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
}
@@ -141,42 +129,47 @@ func probeTLSConfig(caCert string, p connection.Protocol) (*tls.Config, error) {
}
// probeDNS resolves edge addresses for the given region via the supplied
// DNSResolver and returns a CheckResult for each region discovered. If
// resolution fails for all regions, every result will carry StatusFail.
// DNSResolver and returns one ResolvedTarget per discovered region. If
// resolution fails entirely, every ResolvedTarget will carry a Fail DNSResult
// and nil Addrs.
func probeDNS(
resolver DNSResolver,
region string,
) ([][]*allregions.EdgeAddr, []CheckResult) {
) []ResolvedTarget {
region1Target, region2Target := regionTargets(region)
targets := []string{region1Target, region2Target}
addrGroups, err := resolver.Resolve(region)
if err != nil || len(addrGroups) == 0 {
detail := detailsNoAddressesReturned
detail := dnsNoAddressesReturned
if err != nil {
detail = err.Error()
}
return nil, []CheckResult{
newDNSCheckResult(region1Target, Fail, detail, fmt.Sprintf(actionDNSFail, region1Target, region1Target)),
newDNSCheckResult(region2Target, Fail, detail, fmt.Sprintf(actionDNSFail, region2Target, region2Target)),
return []ResolvedTarget{
{DNSResult: newDNSCheckResult(region1Target, Fail, detail, fmt.Sprintf(actionDNSFail, region1Target, region1Target))},
{DNSResult: newDNSCheckResult(region2Target, Fail, detail, fmt.Sprintf(actionDNSFail, region2Target, region2Target))},
}
}
targets := []string{region1Target, region2Target}
results := make([]CheckResult, 0, len(addrGroups))
for i, group := range addrGroups {
target := fmt.Sprintf("region%d.v2.argotunnel.com", i+1)
if i < len(targets) {
target = targets[i]
resolved := make([]ResolvedTarget, 0, len(addrGroups))
for i, target := range targets {
if i >= len(addrGroups) {
break
}
group := addrGroups[i]
if len(group) == 0 {
results = append(results, newDNSCheckResult(target, Fail, detailsNoAddressesReturned, fmt.Sprintf(actionDNSFail, target, target)))
resolved = append(resolved, ResolvedTarget{
DNSResult: newDNSCheckResult(target, Fail, dnsNoAddressesReturned, fmt.Sprintf(actionDNSFail, target, target)),
})
} else {
results = append(results, newDNSCheckResult(target, Pass, detailsResolvedSuccessfully, ""))
resolved = append(resolved, ResolvedTarget{
Addrs: group,
DNSResult: newDNSCheckResult(target, Pass, dnsResolvedSuccessfully, ""),
})
}
}
return addrGroups, results
return resolved
}
// probeQUIC performs a QUIC handshake to a single edge address and returns a
@@ -217,7 +210,7 @@ func probeQUIC(
Component: componentUDPConnectivity,
Target: targetPortQUIC,
ProbeStatus: Fail,
Details: detailsHandshakeFailed,
Details: detailsQUICHandshakeFailed,
Action: actionQUICBlocked,
}
}
@@ -231,7 +224,7 @@ func probeQUIC(
Component: componentUDPConnectivity,
Target: targetPortQUIC,
ProbeStatus: Pass,
Details: detailsHandshakeSuccessful,
Details: detailsQUICHandshakeSuccessful,
}
}
@@ -251,7 +244,7 @@ func probeHTTP2(ctx context.Context, tlsConfig *tls.Config, dialer TCPDialer, ad
Component: componentTCPConnectivity,
Target: targetPortHTTP2,
ProbeStatus: Fail,
Details: detailsBlockedOrUnreachable,
Details: detailsHTTP2BlockedOrUnreachable,
Action: actionHTTP2Blocked,
}
}
@@ -262,7 +255,7 @@ func probeHTTP2(ctx context.Context, tlsConfig *tls.Config, dialer TCPDialer, ad
Component: componentTCPConnectivity,
Target: targetPortHTTP2,
ProbeStatus: Pass,
Details: detailsTLSHandshakeSuccessful,
Details: detailsHTTP2HandshakeSuccessful,
}
}
@@ -281,7 +274,7 @@ func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResul
Component: componentCloudflareAPI,
Target: targetAPI,
ProbeStatus: Fail,
Details: detailsConnectionFailed,
Details: detailsAPIConnectionFailed,
Action: actionAPIUnreachable,
}
}
@@ -292,17 +285,17 @@ func probeManagementAPI(ctx context.Context, dialer ManagementDialer) CheckResul
Component: componentCloudflareAPI,
Target: targetAPI,
ProbeStatus: Pass,
Details: detailsTCPPortReachable,
Details: detailsApiReachable,
}
}
func skipResult(probeType ProbeType, component, target string) CheckResult {
func skipResult(probeType ProbeType, component, target string, details string) CheckResult {
return CheckResult{
Type: probeType,
Component: component,
Target: target,
ProbeStatus: Skip,
Details: detailsDNSPrerequisiteFailed,
Details: details,
}
}
@@ -345,3 +338,39 @@ func addrsByFamily(group []*allregions.EdgeAddr, ipVersion allregions.ConfigIPVe
}
return
}
// runDNSProbe runs probeDNS with retry and returns []ResolvedTarget.
func runDNSProbe(ctx context.Context, resolver DNSResolver, region string) []ResolvedTarget {
var targets []ResolvedTarget
withRetry(ctx, maxRetries, func() bool {
targets = probeDNS(resolver, region)
for _, t := range targets {
if t.DNSResult.ProbeStatus == Fail {
return false
}
}
return len(targets) > 0
})
return targets
}
// resolveStaticEdge resolves each --edge addr individually, returning one
// ResolvedTarget per addr. Unresolvable addrs produce a Fail ResolvedTarget
// with nil Addrs so the report shows which addresses could not be reached.
func resolveStaticEdge(addrs []string, log *zerolog.Logger) []ResolvedTarget {
targets := make([]ResolvedTarget, 0, len(addrs))
for _, addr := range addrs {
resolved := allregions.ResolveAddrs([]string{addr}, log)
if len(resolved) > 0 {
targets = append(targets, ResolvedTarget{
Addrs: resolved,
DNSResult: newDNSCheckResult(addr, Pass, dnsResolvedSuccessfully, ""),
})
} else {
targets = append(targets, ResolvedTarget{
DNSResult: newDNSCheckResult(addr, Fail, dnsNoAddressesReturned, fmt.Sprintf(actionDNSFail, addr, addr)),
})
}
}
return targets
}
+116 -117
View File
@@ -7,7 +7,6 @@ import (
"net"
"testing"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,67 +32,6 @@ const (
// perform a real TLS handshake, so an empty config is sufficient.
var testTLSConfig = &tls.Config{} //nolint:gosec
// mockQuicConnection is a minimal test double for quic.Connection.
type mockQuicConnection struct {
closeErr error
}
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) LocalAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) RemoteAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error {
return m.closeErr
}
func (m *mockQuicConnection) Context() context.Context {
return context.Background()
}
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
return quic.ConnectionState{}
}
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
return nil
}
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
return nil, nil
}
func (m *mockQuicConnection) AddPath(*quic.Transport) (*quic.Path, error) {
return nil, nil
}
// Helper to create test edge addresses.
func createTestEdgeAddr(ip string, port int, version allregions.EdgeIPVersion) *allregions.EdgeAddr {
parsedIP := net.ParseIP(ip)
@@ -117,15 +55,14 @@ func TestProbeDNS_Success(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr, v6Addr}}, nil)
addrs, results := probeDNS(resolver, "")
targets := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 1)
assert.Len(t, addrs, 1)
assert.Equal(t, ProbeTypeDNS, results[0].Type)
assert.Equal(t, testRegion1Global, results[0].Target)
assert.Equal(t, Pass, results[0].ProbeStatus)
assert.Equal(t, detailsResolvedSuccessfully, results[0].Details)
require.Len(t, targets, 1)
assert.NotEmpty(t, targets[0].Addrs)
assert.Equal(t, ProbeTypeDNS, targets[0].DNSResult.Type)
assert.Equal(t, testRegion1Global, targets[0].DNSResult.Target)
assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus)
assert.Equal(t, dnsResolvedSuccessfully, targets[0].DNSResult.Details)
}
func TestProbeDNS_MultipleRegions(t *testing.T) {
@@ -139,17 +76,17 @@ func TestProbeDNS_MultipleRegions(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{v4Addr1}, {v4Addr2}}, nil)
addrs, results := probeDNS(resolver, "")
targets := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 2)
assert.Len(t, addrs, 2)
require.Len(t, targets, 2)
assert.Equal(t, testRegion1Global, results[0].Target)
assert.Equal(t, Pass, results[0].ProbeStatus)
assert.Equal(t, testRegion1Global, targets[0].DNSResult.Target)
assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus)
assert.NotEmpty(t, targets[0].Addrs)
assert.Equal(t, testRegion2Global, results[1].Target)
assert.Equal(t, Pass, results[1].ProbeStatus)
assert.Equal(t, testRegion2Global, targets[1].DNSResult.Target)
assert.Equal(t, Pass, targets[1].DNSResult.ProbeStatus)
assert.NotEmpty(t, targets[1].Addrs)
}
func TestProbeDNS_ResolverError(t *testing.T) {
@@ -160,17 +97,16 @@ func TestProbeDNS_ResolverError(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return(nil, errors.New("DNS lookup failed"))
addrs, results := probeDNS(resolver, "")
targets := probeDNS(resolver, "")
assert.Nil(t, addrs)
require.Len(t, results, 2)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "DNS lookup failed", results[0].Details)
assert.Contains(t, results[0].Action, testRegion1Global)
assert.Contains(t, results[1].Action, testRegion2Global)
assert.Equal(t, Fail, results[1].ProbeStatus)
require.Len(t, targets, 2)
assert.Empty(t, targets[0].Addrs)
assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus)
assert.Equal(t, "DNS lookup failed", targets[0].DNSResult.Details)
assert.Contains(t, targets[0].DNSResult.Action, testRegion1Global)
assert.Empty(t, targets[1].Addrs)
assert.Equal(t, Fail, targets[1].DNSResult.ProbeStatus)
assert.Contains(t, targets[1].DNSResult.Action, testRegion2Global)
}
func TestProbeDNS_EmptyResults(t *testing.T) {
@@ -181,12 +117,12 @@ func TestProbeDNS_EmptyResults(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{}, nil)
addrs, results := probeDNS(resolver, "")
targets := probeDNS(resolver, "")
assert.Nil(t, addrs)
require.Len(t, results, 2)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "No addresses returned", results[0].Details)
require.Len(t, targets, 2)
assert.Empty(t, targets[0].Addrs)
assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus)
assert.Equal(t, dnsNoAddressesReturned, targets[0].DNSResult.Details)
}
func TestProbeDNS_EmptyGroup(t *testing.T) {
@@ -197,12 +133,12 @@ func TestProbeDNS_EmptyGroup(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("").Return([][]*allregions.EdgeAddr{{}}, nil)
addrs, results := probeDNS(resolver, "")
targets := probeDNS(resolver, "")
require.NotNil(t, addrs)
require.Len(t, results, 1)
assert.Equal(t, Fail, results[0].ProbeStatus)
assert.Equal(t, "No addresses returned", results[0].Details)
require.Len(t, targets, 1)
assert.Empty(t, targets[0].Addrs)
assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus)
assert.Equal(t, dnsNoAddressesReturned, targets[0].DNSResult.Details)
}
func TestProbeDNS_RegionFlag(t *testing.T) {
@@ -214,10 +150,10 @@ func TestProbeDNS_RegionFlag(t *testing.T) {
resolver := mocks.NewMockDNSResolver(ctrl)
resolver.EXPECT().Resolve("us").Return([][]*allregions.EdgeAddr{{v4Addr}}, nil)
_, results := probeDNS(resolver, "us")
targets := probeDNS(resolver, "us")
require.Len(t, results, 1)
assert.Equal(t, testRegion1US, results[0].Target)
require.Len(t, targets, 1)
assert.Equal(t, testRegion1US, targets[0].DNSResult.Target)
}
// probeQUIC tests.
@@ -227,9 +163,10 @@ func TestProbeQUIC_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := &mockQuicConnection{}
successfulQUICConn := mocks.NewMockQUICConnection(ctrl)
successfulQUICConn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Return(nil)
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(successfulQUICConn, nil)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
@@ -238,7 +175,7 @@ func TestProbeQUIC_Success(t *testing.T) {
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
assert.Equal(t, detailsQUICHandshakeSuccessful, result.Details)
}
func TestProbeQUIC_DialError(t *testing.T) {
@@ -256,7 +193,7 @@ func TestProbeQUIC_DialError(t *testing.T) {
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsHandshakeFailed, result.Details)
assert.Equal(t, detailsQUICHandshakeFailed, result.Details)
assert.Equal(t, actionQUICBlocked, result.Action)
}
@@ -265,9 +202,12 @@ func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := &mockQuicConnection{closeErr: errors.New("close failed")}
// Return a mock whose CloseWithError returns an error — probeQUIC must still
// report Pass because the handshake itself succeeded.
fakeQUICConn := mocks.NewMockQUICConnection(ctrl)
fakeQUICConn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Return(errors.New("close failed"))
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeQUICConn, nil)
addr := createTestEdgeAddr("192.0.2.1", testEdgePort, allregions.V4)
logger := zerolog.New(nil)
@@ -276,7 +216,7 @@ func TestProbeQUIC_CloseErrorDoesNotAffectResult(t *testing.T) {
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
assert.Equal(t, detailsQUICHandshakeSuccessful, result.Details)
}
func TestProbeQUIC_ContextTimeout(t *testing.T) {
@@ -293,7 +233,7 @@ func TestProbeQUIC_ContextTimeout(t *testing.T) {
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsHandshakeFailed, result.Details)
assert.Equal(t, detailsQUICHandshakeFailed, result.Details)
}
// probeHTTP2 tests.
@@ -312,7 +252,7 @@ func TestProbeHTTP2_Success(t *testing.T) {
assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsTLSHandshakeSuccessful, result.Details)
assert.Equal(t, detailsHTTP2HandshakeSuccessful, result.Details)
}
func TestProbeHTTP2_DialError(t *testing.T) {
@@ -329,7 +269,7 @@ func TestProbeHTTP2_DialError(t *testing.T) {
assert.Equal(t, ProbeTypeHTTP2, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsBlockedOrUnreachable, result.Details)
assert.Equal(t, detailsHTTP2BlockedOrUnreachable, result.Details)
assert.Equal(t, actionHTTP2Blocked, result.Action)
}
@@ -349,7 +289,7 @@ func TestProbeManagementAPI_Success(t *testing.T) {
assert.Equal(t, "Cloudflare API", result.Component)
assert.Equal(t, "api.cloudflare.com:443", result.Target)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsTCPPortReachable, result.Details)
assert.Equal(t, detailsApiReachable, result.Details)
}
func TestProbeManagementAPI_DialError(t *testing.T) {
@@ -364,7 +304,7 @@ func TestProbeManagementAPI_DialError(t *testing.T) {
assert.Equal(t, ProbeTypeManagementAPI, result.Type)
assert.Equal(t, Fail, result.ProbeStatus)
assert.Equal(t, detailsConnectionFailed, result.Details)
assert.Equal(t, detailsAPIConnectionFailed, result.Details)
assert.Equal(t, actionAPIUnreachable, result.Action)
}
@@ -373,7 +313,7 @@ func TestProbeManagementAPI_DialError(t *testing.T) {
func TestSkipResult(t *testing.T) {
t.Parallel()
result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)")
result := skipResult(ProbeTypeQUIC, "UDP Connectivity", "Port 7844 (QUIC)", detailsDNSPrerequisiteFailed)
assert.Equal(t, ProbeTypeQUIC, result.Type)
assert.Equal(t, "UDP Connectivity", result.Component)
@@ -507,10 +447,11 @@ func TestProbeQUIC_IPv6Address(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
successfulQUICConn := mocks.NewMockQUICConnection(ctrl)
successfulQUICConn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Return(nil)
mockConn := &mockQuicConnection{}
dialer := mocks.NewMockQUICDialer(ctrl)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockConn, nil)
dialer.EXPECT().DialQuic(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(successfulQUICConn, nil)
addr := createTestEdgeAddr("2001:db8::1", testEdgePort, allregions.V6)
logger := zerolog.New(nil)
@@ -518,7 +459,7 @@ func TestProbeQUIC_IPv6Address(t *testing.T) {
result := probeQUIC(context.Background(), testTLSConfig, dialer, addr, &logger)
assert.Equal(t, Pass, result.ProbeStatus)
assert.Equal(t, detailsHandshakeSuccessful, result.Details)
assert.Equal(t, detailsQUICHandshakeSuccessful, result.Details)
}
// IPv6 address tests for probeHTTP2.
@@ -537,3 +478,61 @@ func TestProbeHTTP2_IPv6Address(t *testing.T) {
assert.Equal(t, Pass, result.ProbeStatus)
}
// resolveStaticEdge tests.
// TestResolveStaticEdge_SingleAddr verifies that a single resolvable --edge
// addr produces one group labeled with the original addr string.
func TestResolveStaticEdge_SingleAddr(t *testing.T) {
t.Parallel()
logger := zerolog.Nop()
targets := resolveStaticEdge([]string{"127.0.0.1:7844"}, &logger)
require.Len(t, targets, 1)
assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target)
assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus)
assert.NotEmpty(t, targets[0].Addrs)
}
// TestResolveStaticEdge_MultipleAddrs verifies that multiple --edge addrs each
// produce their own ResolvedTarget, preserving per-addr structure and label order.
func TestResolveStaticEdge_MultipleAddrs(t *testing.T) {
t.Parallel()
logger := zerolog.Nop()
targets := resolveStaticEdge([]string{"127.0.0.1:7844", "127.0.0.2:7844"}, &logger)
require.Len(t, targets, 2)
assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target)
assert.Equal(t, "127.0.0.2:7844", targets[1].DNSResult.Target)
}
// TestResolveStaticEdge_InvalidAddr verifies that an unresolvable addr is
// silently skipped and does not appear in the output.
func TestResolveStaticEdge_InvalidAddr(t *testing.T) {
t.Parallel()
logger := zerolog.Nop()
// "not-a-valid-addr" has no port — ResolveTCPAddr will fail.
targets := resolveStaticEdge([]string{"not-a-valid-addr"}, &logger)
require.Len(t, targets, 1)
assert.Equal(t, "not-a-valid-addr", targets[0].DNSResult.Target)
assert.Equal(t, Fail, targets[0].DNSResult.ProbeStatus)
assert.Equal(t, dnsNoAddressesReturned, targets[0].DNSResult.Details)
assert.Empty(t, targets[0].Addrs)
}
// TestResolveStaticEdge_PartiallyValid verifies that a mix of valid and invalid
// addrs produces one ResolvedTarget per addr — valid ones with Addrs and a Skip
// DNSResult, invalid ones with nil Addrs and a Fail DNSResult.
func TestResolveStaticEdge_PartiallyValid(t *testing.T) {
t.Parallel()
logger := zerolog.Nop()
targets := resolveStaticEdge([]string{"127.0.0.1:7844", "not-a-valid-addr", "127.0.0.2:7844"}, &logger)
require.Len(t, targets, 3)
assert.Equal(t, "127.0.0.1:7844", targets[0].DNSResult.Target)
assert.Equal(t, Pass, targets[0].DNSResult.ProbeStatus)
assert.NotEmpty(t, targets[0].Addrs)
assert.Equal(t, "not-a-valid-addr", targets[1].DNSResult.Target)
assert.Equal(t, Fail, targets[1].DNSResult.ProbeStatus)
assert.Empty(t, targets[1].Addrs)
assert.Equal(t, "127.0.0.2:7844", targets[2].DNSResult.Target)
assert.Equal(t, Pass, targets[2].DNSResult.ProbeStatus)
assert.NotEmpty(t, targets[2].Addrs)
}
+2 -2
View File
@@ -11,8 +11,8 @@ import (
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection/dialopts"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
cfdquic "github.com/cloudflare/cloudflared/quic"
)
// DNSResolver abstracts edge DNS discovery used by DNS probes.
@@ -44,7 +44,7 @@ type QUICDialer interface {
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error)
) (cfdquic.QUICConnection, error)
}
// ManagementDialer abstracts the TCP dial to api.cloudflare.com:443 used by
+14 -51
View File
@@ -10,19 +10,12 @@ import (
)
const (
// tableWidth is the total character width of the separator lines.
tableWidth = 80
// Status names.
passStatus = "PASS"
failStatus = "FAIL"
skipStatus = "SKIP"
unknownStatus = "UNKNOWN"
// Section separators.
sectionChar = "-"
headerTitle = "CONNECTIVITY PRE-CHECKS"
// Log message constants.
logMsgPrecheck = "precheck"
logMsgPrecheckComplete = "precheck complete"
@@ -35,8 +28,6 @@ const (
logFieldDetails = "details"
logFieldHardFail = "hard_fail"
logFieldSuggestedProtocol = "suggested_protocol"
sep = " "
)
// statusLabel returns the display label for a given Status.
@@ -58,21 +49,9 @@ func (s Status) logString() string {
return strings.ToLower(s.String())
}
// separator returns a full-width horizontal line.
func separator() string {
return strings.Repeat(sectionChar, tableWidth)
}
// header returns the top section title line.
func header() string {
leftDashes := strings.Repeat(sectionChar, 3)
rightLen := tableWidth - len(leftDashes) - len(headerTitle) - len(sep)*2
return leftDashes + sep + headerTitle + sep + strings.Repeat(sectionChar, rightLen)
}
// renderTable uses text/tabwriter to format the results rows with
// automatically aligned columns, returning the rendered string.
func renderTable(results []CheckResult) string {
// automatically aligned columns, returning the rendered lines.
func renderTable(results []CheckResult) []string {
var buf bytes.Buffer
// minwidth=0, tabwidth=8, padding=2, padchar=' ', flags=0
w := tabwriter.NewWriter(&buf, 0, 8, 2, ' ', 0)
@@ -81,27 +60,27 @@ func renderTable(results []CheckResult) string {
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", r.Component, r.Target, r.ProbeStatus.statusLabel(), r.Details)
}
_ = w.Flush()
return buf.String()
return strings.Split(strings.TrimSuffix(buf.String(), "\n"), "\n")
}
// renderActions collects all non-empty Action strings from results and returns
// the formatted warning/error block that appears between the table and SUMMARY.
// A Fail result is rendered as ERROR when the report is a hard fail, and as
// WARNING otherwise (degraded but tunnel can still run).
func renderActions(r Report) string {
func renderActions(r Report) []string {
hardFail := r.hasHardFail()
var sb strings.Builder
actions := make([]string, 0)
for _, res := range r.Results {
if res.Action == "" || res.ProbeStatus != Fail {
continue
}
if hardFail {
_, _ = fmt.Fprintf(&sb, "ERROR: %s\n", res.Action)
actions = append(actions, fmt.Sprintf("ERROR: %s", res.Action))
} else {
_, _ = fmt.Fprintf(&sb, "WARNING: %s\n", res.Action)
actions = append(actions, fmt.Sprintf("WARNING: %s", res.Action))
}
}
return sb.String()
return actions
}
// summaryLine builds the SUMMARY: line based on the Report state.
@@ -181,28 +160,12 @@ func (r Report) hasWarn() bool {
return (quicFail != http2Fail) || apiFail
}
// String renders the Report as a human-readable table suitable for os.Stdout.
func (r Report) String() string {
var sb strings.Builder
sb.WriteString(header())
sb.WriteString("\n")
sb.WriteString(renderTable(r.Results))
actions := renderActions(r)
if actions != "" {
sb.WriteString(actions)
}
sb.WriteString("\n")
sb.WriteString(summaryLine(r))
sb.WriteString("\n")
sb.WriteString(separator())
sb.WriteString("\n")
return sb.String()
// String renders the Report as human-readable table lines suitable for logging.
func (r Report) String() []string {
lines := renderTable(r.Results)
lines = append(lines, renderActions(r)...)
lines = append(lines, "", summaryLine(r))
return lines
}
// LogEvent emits each CheckResult as a structured zerolog log line, followed by
+85 -90
View File
@@ -25,11 +25,11 @@ func allPassReport() Report {
RunID: fixedRunID,
SuggestedProtocol: new(connection.QUIC),
Results: []CheckResult{
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeQUIC, Component: "UDP Connectivity", Target: "Port 7844 (QUIC)", ProbeStatus: Pass, Details: "Handshake successful"},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: "TLS handshake successful"},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: "Reachable"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeQUIC, Component: "UDP Connectivity", Target: "Port 7844 (QUIC)", ProbeStatus: Pass, Details: detailsQUICHandshakeSuccessful},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: detailsHTTP2HandshakeSuccessful},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: detailsApiReachable},
},
}
}
@@ -41,18 +41,18 @@ func quicBlockedReport() Report {
RunID: fixedRunID,
SuggestedProtocol: new(connection.HTTP2),
Results: []CheckResult{
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{
Type: ProbeTypeQUIC,
Component: "UDP Connectivity",
Target: "Port 7844 (QUIC)",
ProbeStatus: Fail,
Details: "Handshake failed",
Action: "Allow outbound QUIC on port 7844. cloudflared will use http2 in the meantime.",
Details: detailsQUICHandshakeFailed,
Action: actionQUICBlocked,
},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: "TLS handshake successful"},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: "Reachable"},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: detailsHTTP2HandshakeSuccessful},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: detailsApiReachable},
},
}
}
@@ -64,10 +64,10 @@ func apiFailReport() Report {
RunID: fixedRunID,
SuggestedProtocol: new(connection.QUIC),
Results: []CheckResult{
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeQUIC, Component: "UDP Connectivity", Target: "Port 7844 (QUIC)", ProbeStatus: Pass, Details: "Handshake successful"},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: "TLS handshake successful"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeQUIC, Component: "UDP Connectivity", Target: "Port 7844 (QUIC)", ProbeStatus: Pass, Details: detailsQUICHandshakeSuccessful},
{Type: ProbeTypeHTTP2, Component: "TCP Connectivity", Target: "Port 7844 (HTTP/2)", ProbeStatus: Pass, Details: detailsHTTP2HandshakeSuccessful},
{
Type: ProbeTypeManagementAPI,
Component: "Cloudflare API",
@@ -86,14 +86,14 @@ func bothTransportsBlockedReport() Report {
RunID: fixedRunID,
SuggestedProtocol: nil,
Results: []CheckResult{
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: "Resolved successfully"},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region1.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{Type: ProbeTypeDNS, Component: "DNS Resolution", Target: "region2.v2.argotunnel.com", ProbeStatus: Pass, Details: dnsResolvedSuccessfully},
{
Type: ProbeTypeQUIC,
Component: "UDP Connectivity",
Target: "Port 7844 (QUIC)",
ProbeStatus: Fail,
Details: "Handshake failed",
Details: detailsQUICHandshakeFailed,
Action: "Allow outbound QUIC and/or TCP on port 7844 to the Cloudflare edge.",
},
{
@@ -101,9 +101,9 @@ func bothTransportsBlockedReport() Report {
Component: "TCP Connectivity",
Target: "Port 7844 (HTTP/2)",
ProbeStatus: Fail,
Details: "Blocked or unreachable",
Details: detailsHTTP2BlockedOrUnreachable,
},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: "Reachable"},
{Type: ProbeTypeManagementAPI, Component: "Cloudflare API", Target: "api.cloudflare.com:443", ProbeStatus: Pass, Details: detailsApiReachable},
},
}
}
@@ -134,85 +134,80 @@ func dnsFailReport() Report {
func TestString_AllPass(t *testing.T) {
t.Parallel()
want := "" +
"--- CONNECTIVITY PRE-CHECKS ----------------------------------------------------\n" +
"COMPONENT TARGET STATUS DETAILS\n" +
"DNS Resolution region1.v2.argotunnel.com PASS Resolved successfully\n" +
"DNS Resolution region2.v2.argotunnel.com PASS Resolved successfully\n" +
"UDP Connectivity Port 7844 (QUIC) PASS Handshake successful\n" +
"TCP Connectivity Port 7844 (HTTP/2) PASS TLS handshake successful\n" +
"Cloudflare API api.cloudflare.com:443 PASS Reachable\n" +
"\n" +
"SUMMARY: Environment is healthy. cloudflared will use 'quic' as primary protocol.\n" +
"--------------------------------------------------------------------------------\n"
want := []string{
"COMPONENT TARGET STATUS DETAILS",
"DNS Resolution region1.v2.argotunnel.com PASS DNS Resolved successfully",
"DNS Resolution region2.v2.argotunnel.com PASS DNS Resolved successfully",
"UDP Connectivity Port 7844 (QUIC) PASS QUIC connection successful",
"TCP Connectivity Port 7844 (HTTP/2) PASS HTTP/2 connection successful",
"Cloudflare API api.cloudflare.com:443 PASS API is reachable",
"",
"SUMMARY: Environment is healthy. cloudflared will use 'quic' as primary protocol.",
}
assert.Equal(t, want, allPassReport().String())
}
func TestString_QuicBlocked(t *testing.T) {
t.Parallel()
want := "" +
"--- CONNECTIVITY PRE-CHECKS ----------------------------------------------------\n" +
"COMPONENT TARGET STATUS DETAILS\n" +
"DNS Resolution region1.v2.argotunnel.com PASS Resolved successfully\n" +
"DNS Resolution region2.v2.argotunnel.com PASS Resolved successfully\n" +
"UDP Connectivity Port 7844 (QUIC) FAIL Handshake failed\n" +
"TCP Connectivity Port 7844 (HTTP/2) PASS TLS handshake successful\n" +
"Cloudflare API api.cloudflare.com:443 PASS Reachable\n" +
"WARNING: Allow outbound QUIC on port 7844. cloudflared will use http2 in the meantime.\n" +
"\n" +
"SUMMARY: Environment ready with degraded transport. cloudflared will proceed using 'http2'.\n" +
"--------------------------------------------------------------------------------\n"
want := []string{
"COMPONENT TARGET STATUS DETAILS",
"DNS Resolution region1.v2.argotunnel.com PASS DNS Resolved successfully",
"DNS Resolution region2.v2.argotunnel.com PASS DNS Resolved successfully",
"UDP Connectivity Port 7844 (QUIC) FAIL QUIC connection failed",
"TCP Connectivity Port 7844 (HTTP/2) PASS HTTP/2 connection successful",
"Cloudflare API api.cloudflare.com:443 PASS API is reachable",
"WARNING: Allow outbound QUIC traffic on port 7844 or use HTTP2.",
"",
"SUMMARY: Environment ready with degraded transport. cloudflared will proceed using 'http2'.",
}
assert.Equal(t, want, quicBlockedReport().String())
}
func TestString_APIFail(t *testing.T) {
t.Parallel()
want := "" +
"--- CONNECTIVITY PRE-CHECKS ----------------------------------------------------\n" +
"COMPONENT TARGET STATUS DETAILS\n" +
"DNS Resolution region1.v2.argotunnel.com PASS Resolved successfully\n" +
"DNS Resolution region2.v2.argotunnel.com PASS Resolved successfully\n" +
"UDP Connectivity Port 7844 (QUIC) PASS Handshake successful\n" +
"TCP Connectivity Port 7844 (HTTP/2) PASS TLS handshake successful\n" +
"Cloudflare API api.cloudflare.com:443 FAIL Connection refused\n" +
"WARNING: cloudflared will still run, but automatic software updates are unavailable. Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates.\n" +
"\n" +
"SUMMARY: Environment ready with degraded transport. cloudflared will proceed using 'quic'.\n" +
"--------------------------------------------------------------------------------\n"
want := []string{
"COMPONENT TARGET STATUS DETAILS",
"DNS Resolution region1.v2.argotunnel.com PASS DNS Resolved successfully",
"DNS Resolution region2.v2.argotunnel.com PASS DNS Resolved successfully",
"UDP Connectivity Port 7844 (QUIC) PASS QUIC connection successful",
"TCP Connectivity Port 7844 (HTTP/2) PASS HTTP/2 connection successful",
"Cloudflare API api.cloudflare.com:443 FAIL Connection refused",
"WARNING: cloudflared will still run, but automatic software updates are unavailable. Ensure port 443 TCP to api.cloudflare.com is open if you want auto-updates.",
"",
"SUMMARY: Environment ready with degraded transport. cloudflared will proceed using 'quic'.",
}
assert.Equal(t, want, apiFailReport().String())
}
func TestString_BothTransportsBlocked(t *testing.T) {
t.Parallel()
want := "" +
"--- CONNECTIVITY PRE-CHECKS ----------------------------------------------------\n" +
"COMPONENT TARGET STATUS DETAILS\n" +
"DNS Resolution region1.v2.argotunnel.com PASS Resolved successfully\n" +
"DNS Resolution region2.v2.argotunnel.com PASS Resolved successfully\n" +
"UDP Connectivity Port 7844 (QUIC) FAIL Handshake failed\n" +
"TCP Connectivity Port 7844 (HTTP/2) FAIL Blocked or unreachable\n" +
"Cloudflare API api.cloudflare.com:443 PASS Reachable\n" +
"ERROR: Allow outbound QUIC and/or TCP on port 7844 to the Cloudflare edge.\n" +
"\n" +
"SUMMARY: Environment has critical failures. cloudflared may not be able to establish a tunnel.\n" +
"--------------------------------------------------------------------------------\n"
want := []string{
"COMPONENT TARGET STATUS DETAILS",
"DNS Resolution region1.v2.argotunnel.com PASS DNS Resolved successfully",
"DNS Resolution region2.v2.argotunnel.com PASS DNS Resolved successfully",
"UDP Connectivity Port 7844 (QUIC) FAIL QUIC connection failed",
"TCP Connectivity Port 7844 (HTTP/2) FAIL HTTP/2 connection is blocked or unreachable",
"Cloudflare API api.cloudflare.com:443 PASS API is reachable",
"ERROR: Allow outbound QUIC and/or TCP on port 7844 to the Cloudflare edge.",
"",
"SUMMARY: Environment has critical failures. cloudflared may not be able to establish a tunnel.",
}
assert.Equal(t, want, bothTransportsBlockedReport().String())
}
func TestString_DNSFail(t *testing.T) {
t.Parallel()
want := "" +
"--- CONNECTIVITY PRE-CHECKS ----------------------------------------------------\n" +
"COMPONENT TARGET STATUS DETAILS\n" +
"DNS Resolution region1.v2.argotunnel.com FAIL No addresses returned\n" +
"DNS Resolution region2.v2.argotunnel.com FAIL No addresses returned\n" +
"UDP Connectivity Port 7844 (QUIC) SKIP DNS prerequisite failed\n" +
"TCP Connectivity Port 7844 (HTTP/2) SKIP DNS prerequisite failed\n" +
"Cloudflare API api.cloudflare.com:443 FAIL Connection refused\n" +
"ERROR: Ensure your DNS resolver can resolve 'region1.v2.argotunnel.com'. Run: dig A region1.v2.argotunnel.com @1.1.1.1. If that fails, contact your network administrator.\n" +
"\n" +
"SUMMARY: Environment has critical failures. cloudflared may not be able to establish a tunnel.\n" +
"--------------------------------------------------------------------------------\n"
want := []string{
"COMPONENT TARGET STATUS DETAILS",
"DNS Resolution region1.v2.argotunnel.com FAIL No addresses returned",
"DNS Resolution region2.v2.argotunnel.com FAIL No addresses returned",
"UDP Connectivity Port 7844 (QUIC) SKIP DNS prerequisite failed",
"TCP Connectivity Port 7844 (HTTP/2) SKIP DNS prerequisite failed",
"Cloudflare API api.cloudflare.com:443 FAIL Connection refused",
"ERROR: Ensure your DNS resolver can resolve 'region1.v2.argotunnel.com'. Run: dig A region1.v2.argotunnel.com @1.1.1.1. If that fails, contact your network administrator.",
"",
"SUMMARY: Environment has critical failures. cloudflared may not be able to establish a tunnel.",
}
assert.Equal(t, want, dnsFailReport().String())
}
@@ -221,9 +216,9 @@ func TestString_EmptyResults(t *testing.T) {
r := Report{RunID: fixedRunID, SuggestedProtocol: new(connection.QUIC)}
out := r.String()
// Must not panic and must still emit a valid skeleton.
assert.Contains(t, out, "CONNECTIVITY PRE-CHECKS")
assert.Contains(t, out, "SUMMARY:")
assert.Contains(t, out, separator())
require.Len(t, out, 3)
assert.Contains(t, out[0], "COMPONENT")
assert.Contains(t, out[2], "SUMMARY:")
}
// LogEvent() / structured log renderer tests
@@ -276,11 +271,11 @@ func TestLogEvent_AllPass(t *testing.T) {
status string
details string
}{
{"DNS Resolution", "region1.v2.argotunnel.com", "pass", "Resolved successfully"},
{"DNS Resolution", "region2.v2.argotunnel.com", "pass", "Resolved successfully"},
{"UDP Connectivity", "Port 7844 (QUIC)", "pass", "Handshake successful"},
{"TCP Connectivity", "Port 7844 (HTTP/2)", "pass", "TLS handshake successful"},
{"Cloudflare API", "api.cloudflare.com:443", "pass", "Reachable"},
{"DNS Resolution", "region1.v2.argotunnel.com", "pass", dnsResolvedSuccessfully},
{"DNS Resolution", "region2.v2.argotunnel.com", "pass", dnsResolvedSuccessfully},
{"UDP Connectivity", "Port 7844 (QUIC)", "pass", detailsQUICHandshakeSuccessful},
{"TCP Connectivity", "Port 7844 (HTTP/2)", "pass", detailsHTTP2HandshakeSuccessful},
{"Cloudflare API", "api.cloudflare.com:443", "pass", detailsApiReachable},
}
for i, exp := range expected {
e := entries[i]
@@ -312,7 +307,7 @@ func TestLogEvent_QuicBlocked(t *testing.T) {
assert.Equal(t, "fail", quic.Status)
assert.Equal(t, "UDP Connectivity", quic.Component)
assert.Equal(t, "Port 7844 (QUIC)", quic.Target)
assert.Equal(t, "Handshake failed", quic.Details)
assert.Equal(t, "QUIC connection failed", quic.Details)
assert.Equal(t, fixedRunID.String(), quic.RunID)
// Summary: not a hard fail (HTTP/2 still works), protocol falls back to http2.
@@ -354,9 +349,9 @@ func TestLogEvent_BothTransportsBlocked(t *testing.T) {
// Both transport rows carry status=fail.
assert.Equal(t, "fail", entries[2].Status)
assert.Equal(t, "Handshake failed", entries[2].Details)
assert.Equal(t, "QUIC connection failed", entries[2].Details)
assert.Equal(t, "fail", entries[3].Status)
assert.Equal(t, "Blocked or unreachable", entries[3].Details)
assert.Equal(t, "HTTP/2 connection is blocked or unreachable", entries[3].Details)
summary := entries[len(entries)-1]
require.NotNil(t, summary.HardFail)
+28
View File
@@ -74,6 +74,19 @@ type CheckResult struct {
Action string
}
// ResolvedTarget bundles a resolved edge target's addresses with the DNS
// CheckResult that describes it. This keeps addr groups and their report rows
// together as a single unit, avoiding parallel-slice synchronization.
type ResolvedTarget struct {
// Addrs holds the resolved edge addresses for this target. May be empty
// when DNS resolution succeeded structurally but returned no IPs.
Addrs []*allregions.EdgeAddr
// DNSResult is the CheckResult representing DNS resolution for this target.
// Its Target field is the human-readable label used across all probe rows.
DNSResult CheckResult
}
// Report aggregates all CheckResults produced by a single Run() invocation.
// Pre-checks run in parallel with tunnel initialization and are purely
// diagnostic: the Report is displayed to the user but never gates startup.
@@ -107,4 +120,19 @@ type Config struct {
// checks. It mirrors the --edge-ip-version CLI flag so that the pre-check
// exercises the same code paths the tunnel itself will use.
IPVersion allregions.ConfigIPVersion
// EdgeAddrs, when non-empty, contains the --edge flag values (explicit
// edge addresses). When set, DNS probing is skipped entirely — there are
// no SRV records to validate — and transport probes target each addr
// individually, labeled with the original addr string.
EdgeAddrs []string
// ProtocolOverride is the raw --protocol flag value (e.g. "quic",
// "http2", "h2mux"). When non-empty and not "auto", the pre-checks still
// probe both transports for diagnostic completeness, but the reported
// SuggestedProtocol honours the override so that the summary message
// reflects what cloudflared will actually use — not what the probe
// heuristic would recommend on its own. Parsing happens inside the
// prechecks package.
ProtocolOverride string
}
+8 -8
View File
@@ -137,7 +137,7 @@ func (p *Proxy) ProxyHTTP(
p.proxyLocalRequest(originProxy, w, req, isWebsocket)
return nil
default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
return fmt.Errorf("unrecognized service: %s, %t", rule.Service, originProxy)
}
}
@@ -193,7 +193,7 @@ func (p *Proxy) proxyHTTPRequest(
) error {
roundTripReq := tr.Request
if isWebsocket {
roundTripReq = tr.Clone(tr.Request.Context())
roundTripReq = tr.Clone(tr.Context())
roundTripReq.Header.Set("Connection", "Upgrade")
roundTripReq.Header.Set("Upgrade", "websocket")
roundTripReq.Header.Set("Sec-Websocket-Version", "13")
@@ -203,7 +203,7 @@ func (p *Proxy) proxyHTTPRequest(
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if disableChunkedEncoding {
roundTripReq.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(tr.Request.Header.Get("Content-Length"))
cLength, err := strconv.Atoi(tr.Header.Get("Content-Length"))
if err == nil {
roundTripReq.ContentLength = int64(cLength)
}
@@ -228,7 +228,7 @@ func (p *Proxy) proxyHTTPRequest(
}
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
headers := make(http.Header, len(resp.Header))
// copy headers
@@ -249,11 +249,11 @@ func (p *Proxy) proxyHTTPRequest(
if !ok {
return errors.New("internal error: unsupported connection type")
}
defer rwc.Close()
defer func() { _ = rwc.Close() }()
eyeballStream := &bidirectionalStream{
writer: w,
reader: tr.Request.Body,
reader: tr.Body,
}
stream.Pipe(eyeballStream, rwc, logger)
@@ -292,7 +292,7 @@ func (p *Proxy) proxyStream(
return err
}
connectSpan.End()
defer originConn.Close()
defer func() { _ = originConn.Close() }()
logger.Debug().Msg("origin connection established")
encodedSpans := tr.GetSpans()
@@ -331,7 +331,7 @@ func (p *Proxy) proxyTCPStream(
return err
}
connectSpan.End()
defer originConn.Close()
defer func() { _ = originConn.Close() }()
logger.Debug().Msg("origin connection established")
encodedSpans := tr.GetSpans()
+52 -52
View File
@@ -4,85 +4,85 @@ import (
"strconv"
"time"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
// Helper to convert logging.ByteCount(alias for int64) to float64 used in prometheus
func byteCountToPromCount(count logging.ByteCount) float64 {
// byteCountToPromCount converts an int64 byte count to float64 used in prometheus.
func byteCountToPromCount(count int64) float64 {
return float64(count)
}
// Helper to convert Duration to float64 used in prometheus
// durationToPromGauge converts a Duration to float64 milliseconds used in prometheus.
func durationToPromGauge(duration time.Duration) float64 {
return float64(duration.Milliseconds())
}
// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketType into string
func packetTypeString(pt logging.PacketType) string {
// packetTypeString converts a qlog.PacketType to a Prometheus-safe label string.
// The allowlist prevents unbounded cardinality if upstream adds new values.
func packetTypeString(pt qlog.PacketType) string {
switch pt {
case logging.PacketTypeInitial:
return "initial"
case logging.PacketTypeHandshake:
return "handshake"
case logging.PacketTypeRetry:
return "retry"
case logging.PacketType0RTT:
return "0_rtt"
case logging.PacketTypeVersionNegotiation:
return "version_negotiation"
case logging.PacketType1RTT:
return "1_rtt"
case logging.PacketTypeStatelessReset:
return "stateless_reset"
case logging.PacketTypeNotDetermined:
return "undetermined"
case qlog.PacketTypeInitial,
qlog.PacketTypeHandshake,
qlog.PacketType0RTT,
qlog.PacketType1RTT,
qlog.PacketTypeRetry,
qlog.PacketTypeVersionNegotiation,
qlog.PacketTypeStatelessReset:
return string(pt)
default:
return "unknown_packet_type"
}
}
// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketDropReason into string
func packetDropReasonString(reason logging.PacketDropReason) string {
// packetDropReasonString converts a qlog.PacketDropReason to a Prometheus-safe label string.
// The allowlist passes known values through and guards against unbounded cardinality.
func packetDropReasonString(reason qlog.PacketDropReason) string {
switch reason {
case logging.PacketDropKeyUnavailable:
return "key_unavailable"
case logging.PacketDropUnknownConnectionID:
return "unknown_conn_id"
case logging.PacketDropHeaderParseError:
return "header_parse_err"
case logging.PacketDropPayloadDecryptError:
return "payload_decrypt_err"
case logging.PacketDropProtocolViolation:
return "protocol_violation"
case logging.PacketDropDOSPrevention:
return "dos_prevention"
case logging.PacketDropUnsupportedVersion:
return "unsupported_version"
case logging.PacketDropUnexpectedPacket:
return "unexpected_packet"
case logging.PacketDropUnexpectedSourceConnectionID:
return "unexpected_src_conn_id"
case logging.PacketDropUnexpectedVersion:
return "unexpected_version"
case logging.PacketDropDuplicate:
return "duplicate"
case qlog.PacketDropKeyUnavailable,
qlog.PacketDropUnknownConnectionID,
qlog.PacketDropHeaderParseError,
qlog.PacketDropPayloadDecryptError,
qlog.PacketDropProtocolViolation,
qlog.PacketDropDOSPrevention,
qlog.PacketDropUnsupportedVersion,
qlog.PacketDropUnexpectedPacket,
qlog.PacketDropUnexpectedSourceConnectionID,
qlog.PacketDropUnexpectedVersion,
qlog.PacketDropDuplicate:
return string(reason)
default:
return "unknown_reason"
}
}
// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketLossReason into string
func packetLossReasonString(reason logging.PacketLossReason) string {
// packetLossReasonString converts a qlog.PacketLossReason to a Prometheus-safe label string.
func packetLossReasonString(reason qlog.PacketLossReason) string {
switch reason {
case logging.PacketLossReorderingThreshold:
return "reordering"
case logging.PacketLossTimeThreshold:
return "timeout"
case qlog.PacketLossReorderingThreshold,
qlog.PacketLossTimeThreshold:
return string(reason)
default:
return "unknown_loss_reason"
}
}
// congestionStateToFloat maps a qlog.CongestionState string to a numeric value for prometheus gauges.
// Mapping: slow_start=0, congestion_avoidance=1, application_limited=2, recovery=3, unknown=-1.
func congestionStateToFloat(state qlog.CongestionState) float64 {
switch state {
case qlog.CongestionStateSlowStart:
return 0
case qlog.CongestionStateCongestionAvoidance:
return 1
case qlog.CongestionStateApplicationLimited:
return 2
case qlog.CongestionStateRecovery:
return 3
default:
return -1
}
}
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}
+2 -3
View File
@@ -6,7 +6,6 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/packet"
@@ -25,12 +24,12 @@ type BaseDatagramMuxer interface {
}
type DatagramMuxer struct {
session quic.Connection
session QUICConnection
logger *zerolog.Logger
demuxChan chan<- *packet.Session
}
func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *packet.Session) *DatagramMuxer {
func NewDatagramMuxer(quicSession QUICConnection, log *zerolog.Logger, demuxChan chan<- *packet.Session) *DatagramMuxer {
logger := log.With().Uint8("datagramVersion", 1).Logger()
return &DatagramMuxer{
session: quicSession,
+12 -8
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/packet"
@@ -51,14 +50,14 @@ func (dm *DatagramMuxerV2) mtu() int {
}
type DatagramMuxerV2 struct {
session quic.Connection
session QUICConnection
logger *zerolog.Logger
sessionDemuxChan chan<- *packet.Session
packetDemuxChan chan Packet
}
func NewDatagramMuxerV2(
quicSession quic.Connection,
quicSession QUICConnection,
log *zerolog.Logger,
sessionDemuxChan chan<- *packet.Session,
) *DatagramMuxerV2 {
@@ -110,7 +109,8 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
return nil
}
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
// ServeReceive reads datagrams from the QUIC connection and demuxes them
// depending on whether it's a session or packet
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
for {
msg, err := dm.session.ReceiveDatagram(ctx)
@@ -141,11 +141,13 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error
}
msgType := DatagramV2Type(msgWithType[len(msgWithType)-typeIDLen])
msg := msgWithType[0 : len(msgWithType)-typeIDLen]
switch msgType {
switch msgType { //nolint:exhaustive // default handles all non-UDP types via handlePacket
case DatagramTypeUDP:
return dm.handleSession(ctx, msg)
default:
case DatagramTypeIP, DatagramTypeIPWithTrace, DatagramTypeTracingSpan:
return dm.handlePacket(ctx, msg, msgType)
default:
return fmt.Errorf("unexpected datagram type %d", msgType)
}
}
@@ -168,7 +170,7 @@ func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) er
func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType DatagramV2Type) error {
var demuxedPacket Packet
switch msgType {
switch msgType { //nolint:exhaustive // DatagramTypeUDP is handled by the caller (demux)
case DatagramTypeIP:
demuxedPacket = RawPacket(packet.RawPacket{Data: pk})
case DatagramTypeIPWithTrace:
@@ -189,8 +191,10 @@ func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType
Spans: spans,
TracingIdentity: tracingIdentity,
}
case DatagramTypeUDP:
return fmt.Errorf("unexpected datagram type %d in handlePacket", msgType)
default:
return fmt.Errorf("Unexpected datagram type %d", msgType)
return fmt.Errorf("unexpected datagram type %d", msgType)
}
select {
case <-ctx.Done():
+50 -29
View File
@@ -4,9 +4,10 @@ import (
"reflect"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
"github.com/rs/zerolog"
)
@@ -175,7 +176,7 @@ var (
Namespace: namespace,
Subsystem: "client",
Name: "congestion_state",
Help: "Current congestion control state. See https://pkg.go.dev/github.com/quic-go/quic-go@v0.45.0/logging#CongestionState for what each value maps to",
Help: "Current congestion control state (0=slow_start, 1=congestion_avoidance, 2=application_limited, 3=recovery, -1=unknown)",
},
[]string{ConnectionIndexMetricLabel},
),
@@ -229,28 +230,37 @@ func (cc *clientCollector) startedConnection() {
clientMetrics.totalConnections.Inc()
}
func (cc *clientCollector) closedConnection(error) {
func (cc *clientCollector) closedConnection() {
clientMetrics.closedConnections.Inc()
}
func (cc *clientCollector) receivedTransportParameters(params *logging.TransportParameters) {
clientMetrics.maxUDPPayloadSize.WithLabelValues(cc.index).Set(float64(params.MaxUDPPayloadSize))
cc.logger.Debug().Msgf("Received transport parameters: MaxUDPPayloadSize=%d, MaxIdleTimeout=%v, MaxDatagramFrameSize=%d", params.MaxUDPPayloadSize, params.MaxIdleTimeout, params.MaxDatagramFrameSize)
// receivedTransportParameters records metrics from the peer's transport parameters.
func (cc *clientCollector) receivedTransportParameters(maxUDPPayloadSize int64, maxIdleTimeout time.Duration, maxDatagramFrameSize int64) {
clientMetrics.maxUDPPayloadSize.WithLabelValues(cc.index).Set(float64(maxUDPPayloadSize))
cc.logger.
Debug().
Int64("MaxUDPPayloadSize", maxUDPPayloadSize).
Dur("MaxIdleTimeout", maxIdleTimeout).
Int64("MaxDatagramFrameSize", maxDatagramFrameSize).Msgf("Received transport parameters")
}
func (cc *clientCollector) sentPackets(size logging.ByteCount, frames []logging.Frame) {
// sentPackets records metrics for sent packets.
func (cc *clientCollector) sentPackets(size int64, frames []qlog.Frame) {
cc.collectPackets(size, frames, clientMetrics.sentFrames, clientMetrics.sentBytes, sent)
}
func (cc *clientCollector) receivedPackets(size logging.ByteCount, frames []logging.Frame) {
// receivedPackets records metrics for received packets.
func (cc *clientCollector) receivedPackets(size int64, frames []qlog.Frame) {
cc.collectPackets(size, frames, clientMetrics.receivedFrames, clientMetrics.receivedBytes, received)
}
func (cc *clientCollector) bufferedPackets(packetType logging.PacketType) {
// bufferedPackets records metrics for buffered packets.
func (cc *clientCollector) bufferedPackets(packetType qlog.PacketType) {
clientMetrics.bufferedPackets.WithLabelValues(cc.index, packetTypeString(packetType)).Inc()
}
func (cc *clientCollector) droppedPackets(packetType logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) {
// droppedPackets records metrics for dropped packets.
func (cc *clientCollector) droppedPackets(packetType qlog.PacketType, size int64, reason qlog.PacketDropReason) {
clientMetrics.droppedPackets.WithLabelValues(
cc.index,
packetTypeString(packetType),
@@ -258,35 +268,43 @@ func (cc *clientCollector) droppedPackets(packetType logging.PacketType, size lo
).Add(byteCountToPromCount(size))
}
func (cc *clientCollector) lostPackets(reason logging.PacketLossReason) {
// lostPackets records metrics for lost packets.
func (cc *clientCollector) lostPackets(reason qlog.PacketLossReason) {
clientMetrics.lostPackets.WithLabelValues(cc.index, packetLossReasonString(reason)).Inc()
}
func (cc *clientCollector) updatedRTT(rtt *logging.RTTStats) {
clientMetrics.minRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.MinRTT()))
clientMetrics.latestRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.LatestRTT()))
clientMetrics.smoothedRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.SmoothedRTT()))
// updatedRTT records RTT metrics.
func (cc *clientCollector) updatedRTT(m qlog.MetricsUpdated) {
clientMetrics.minRTT.WithLabelValues(cc.index).Set(durationToPromGauge(m.MinRTT))
clientMetrics.latestRTT.WithLabelValues(cc.index).Set(durationToPromGauge(m.LatestRTT))
clientMetrics.smoothedRTT.WithLabelValues(cc.index).Set(durationToPromGauge(m.SmoothedRTT))
}
func (cc *clientCollector) updateCongestionWindow(size logging.ByteCount) {
// updateCongestionWindow records the congestion window size.
func (cc *clientCollector) updateCongestionWindow(size int64) {
clientMetrics.congestionWindow.WithLabelValues(cc.index).Set(float64(size))
}
func (cc *clientCollector) updatedCongestionState(state logging.CongestionState) {
clientMetrics.congestionState.WithLabelValues(cc.index).Set(float64(state))
// updatedCongestionState records the congestion control state.
func (cc *clientCollector) updatedCongestionState(state qlog.CongestionState) {
clientMetrics.congestionState.WithLabelValues(cc.index).Set(congestionStateToFloat(state))
}
func (cc *clientCollector) updateMTU(mtu logging.ByteCount) {
// updateMTU records the MTU value.
func (cc *clientCollector) updateMTU(mtu int64) {
clientMetrics.mtu.WithLabelValues(cc.index).Set(float64(mtu))
cc.logger.Debug().Msgf("QUIC MTU updated to %d", mtu)
}
func (cc *clientCollector) collectPackets(size logging.ByteCount, frames []logging.Frame, counter, bandwidth *prometheus.CounterVec, direction direction) {
// collectPackets is the shared implementation for sentPackets and receivedPackets.
func (cc *clientCollector) collectPackets(size int64, frames []qlog.Frame, counter, bandwidth *prometheus.CounterVec, direction direction) {
for _, frame := range frames {
switch f := frame.(type) {
case logging.DataBlockedFrame:
cc.logger.Debug().Msgf("%s data_blocked frame", direction)
case logging.StreamDataBlockedFrame:
// qlog.Frame.Frame holds the concrete wire frame type as any.
// The quic-go encoder always stores pointers (*wire.XxxFrame).
switch f := frame.Frame.(type) {
case *qlog.DataBlockedFrame:
cc.logger.Debug().Int64("limit", int64(f.MaximumData)).Msgf("%s data_blocked frame", direction)
case *qlog.StreamDataBlockedFrame:
cc.logger.Debug().Int64("streamID", int64(f.StreamID)).Msgf("%s stream_data_blocked frame", direction)
}
counter.WithLabelValues(cc.index, frameName(frame)).Inc()
@@ -294,13 +312,16 @@ func (cc *clientCollector) collectPackets(size logging.ByteCount, frames []loggi
bandwidth.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
}
func frameName(frame logging.Frame) string {
if frame == nil {
// frameName extracts the type name from a qlog.Frame for use as a Prometheus label.
func frameName(frame qlog.Frame) string {
if frame.Frame == nil {
return "nil"
} else {
name := reflect.TypeOf(frame).Elem().Name()
return strings.TrimSuffix(name, "Frame")
}
t := reflect.TypeOf(frame.Frame)
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
return strings.TrimSuffix(t.Name(), "Frame")
}
type direction uint8
+104
View File
@@ -0,0 +1,104 @@
package quic
import (
"context"
"errors"
"io"
"net"
"github.com/quic-go/quic-go"
)
// QUICConnection defines the subset of [quic.Connection] methods used by cloudflared.
// Consumers should accept this interface; producers should return [*ConnWithCloser].
type QUICConnection interface {
AcceptStream(ctx context.Context) (*quic.Stream, error)
OpenStream() (*quic.Stream, error)
OpenStreamSync(ctx context.Context) (*quic.Stream, error)
CloseWithError(code quic.ApplicationErrorCode, reason string) error
Context() context.Context
SendDatagram(payload []byte) error
ReceiveDatagram(ctx context.Context) ([]byte, error)
LocalAddr() net.Addr
RemoteAddr() net.Addr
ConnectionState() quic.ConnectionState
}
// Compile-time assertion that *ConnWithCloser implements QUICConnection.
var _ QUICConnection = (*ConnWithCloser)(nil)
var (
// error returned when the [NewQUICConnection] is called with a nil conn argument
ErrNilQuicConnection = errors.New("the provided quic connection is nil")
// error returned when the [NewQUICConnection] is called with a nil closer argument
ErrNilCloser = errors.New("the provided closer is nil")
)
// ConnWithCloser wraps a [quic.Connection] and an [io.Closer] (typically the
// underlying [*net.UDPConn]). When [CloseWithError] is called the QUIC
// connection is closed first, then the closer is closed deterministically.
//
// All fields are non-nil after successful construction via [NewQUICConnection].
type ConnWithCloser struct {
conn *quic.Conn
closer io.Closer
}
// NewQUICConnection returns a [*ConnWithCloser] that will close closer after
// the QUIC connection is closed.
func NewQUICConnection(conn *quic.Conn, closer io.Closer) (*ConnWithCloser, error) {
if conn == nil {
return nil, ErrNilQuicConnection
}
if closer == nil {
return nil, ErrNilCloser
}
return &ConnWithCloser{conn: conn, closer: closer}, nil
}
// CloseWithError closes the QUIC connection and then closes the underlying
// [io.Closer]. If both operations return errors, the errors are joined so that
// the closer error is no longer silently discarded.
func (c *ConnWithCloser) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
connErr := c.conn.CloseWithError(code, reason)
closerErr := c.closer.Close()
return errors.Join(connErr, closerErr)
}
func (c *ConnWithCloser) AcceptStream(ctx context.Context) (*quic.Stream, error) {
return c.conn.AcceptStream(ctx)
}
func (c *ConnWithCloser) OpenStream() (*quic.Stream, error) {
return c.conn.OpenStream()
}
func (c *ConnWithCloser) OpenStreamSync(ctx context.Context) (*quic.Stream, error) {
return c.conn.OpenStreamSync(ctx)
}
func (c *ConnWithCloser) Context() context.Context {
return c.conn.Context()
}
func (c *ConnWithCloser) SendDatagram(payload []byte) error {
return c.conn.SendDatagram(payload)
}
func (c *ConnWithCloser) ReceiveDatagram(ctx context.Context) ([]byte, error) {
return c.conn.ReceiveDatagram(ctx)
}
func (c *ConnWithCloser) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *ConnWithCloser) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *ConnWithCloser) ConnectionState() quic.ConnectionState {
return c.conn.ConnectionState()
}
+31
View File
@@ -0,0 +1,31 @@
package quic
import (
"testing"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
// mockCloser is an [io.Closer] that returns a configurable error.
type mockCloser struct {
closeErr error
}
func (m *mockCloser) Close() error {
return m.closeErr
}
func TestNewConnWithCloser_NilConn(t *testing.T) {
t.Parallel()
conn, err := NewQUICConnection(nil, &mockCloser{})
require.ErrorIs(t, err, ErrNilQuicConnection)
require.Nil(t, conn)
}
func TestNewConnWithCloser_NilCloser(t *testing.T) {
t.Parallel()
conn, err := NewQUICConnection(&quic.Conn{}, nil)
require.ErrorIs(t, err, ErrNilCloser)
require.Nil(t, conn)
}
+2 -2
View File
@@ -17,13 +17,13 @@ var idleTimeoutError = quic.IdleTimeoutError{}
type SafeStreamCloser struct {
lock sync.Mutex
stream quic.Stream
stream *quic.Stream
writeTimeout time.Duration
log *zerolog.Logger
closing atomic.Bool
}
func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
func NewSafeStreamCloser(stream *quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
return &SafeStreamCloser{
stream: stream,
writeTimeout: writeTimeout,
+79 -52
View File
@@ -2,19 +2,20 @@ package quic
import (
"context"
"net"
"time"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
"github.com/quic-go/quic-go/qlogwriter"
"github.com/rs/zerolog"
)
// QUICTracer is a wrapper to create new quicConnTracer
// tracer builds a connTracer for each new QUIC connection.
type tracer struct {
index string
logger *zerolog.Logger
}
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) *logging.ConnectionTracer {
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, bool, qlog.ConnectionID) qlogwriter.Trace {
t := &tracer{
index: uint8ToString(index),
logger: logger,
@@ -22,85 +23,111 @@ func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context,
return t.TracerForConnection
}
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer {
// TracerForConnection returns a qlogwriter.Trace for a new connection.
func (t *tracer) TracerForConnection(_ context.Context, _ bool, _ qlog.ConnectionID) qlogwriter.Trace {
return newConnTracer(newClientCollector(t.index, t.logger))
}
// connTracer collects connection level metrics
// connTracer collects connection level metrics. It implements
// qlogwriter.Trace + qlogwriter.Recorder and dispatches qlog events to the
// metric-collection methods via RecordEvent.
type connTracer struct {
metricsCollector *clientCollector
}
func newConnTracer(metricsCollector *clientCollector) *logging.ConnectionTracer {
tracer := connTracer{
func newConnTracer(metricsCollector *clientCollector) *connTracer {
return &connTracer{
metricsCollector: metricsCollector,
}
return &logging.ConnectionTracer{
StartedConnection: tracer.StartedConnection,
ClosedConnection: tracer.ClosedConnection,
ReceivedTransportParameters: tracer.ReceivedTransportParameters,
SentLongHeaderPacket: tracer.SentLongHeaderPacket,
SentShortHeaderPacket: tracer.SentShortHeaderPacket,
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
BufferedPacket: tracer.BufferedPacket,
DroppedPacket: tracer.DroppedPacket,
UpdatedMetrics: tracer.UpdatedMetrics,
LostPacket: tracer.LostPacket,
UpdatedMTU: tracer.UpdatedMTU,
UpdatedCongestionState: tracer.UpdatedCongestionState,
}
}
func (ct *connTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) {
func (ct *connTracer) AddProducer() qlogwriter.Recorder {
// connTracer is both the Trace and the Recorder: each connection gets
// exactly one producer that routes events to the collector methods below.
return ct
}
func (ct *connTracer) SupportsSchemas(_ string) bool {
return true
}
// RecordEvent dispatches qlog events to the collector methods.
func (ct *connTracer) RecordEvent(ev qlogwriter.Event) {
switch e := ev.(type) {
case qlog.StartedConnection:
ct.StartedConnection()
case qlog.ConnectionClosed:
ct.ClosedConnection()
case qlog.ParametersSet:
// ParametersSet fires for both local and remote; filter to remote only
// via the Initiator field.
if e.Initiator == qlog.InitiatorRemote {
ct.ReceivedTransportParameters(int64(e.MaxUDPPayloadSize), e.MaxIdleTimeout, int64(e.MaxDatagramFrameSize))
}
case qlog.PacketSent:
ct.SentPacket(int64(e.Raw.Length), e.Frames)
case qlog.PacketReceived:
ct.ReceivedPacket(int64(e.Raw.Length), e.Frames)
case qlog.PacketBuffered:
ct.BufferedPacket(e.Header.PacketType)
case qlog.PacketDropped:
ct.DroppedPacket(e.Header.PacketType, int64(e.Raw.Length), e.Trigger)
case qlog.PacketLost:
ct.LostPacket(e.Trigger)
case qlog.MetricsUpdated:
ct.UpdatedMetrics(e)
case qlog.MTUUpdated:
ct.UpdatedMTU(int64(e.Value))
case qlog.CongestionStateUpdated:
ct.UpdatedCongestionState(e.State)
}
}
func (ct *connTracer) Close() error {
return nil
}
func (ct *connTracer) StartedConnection() {
ct.metricsCollector.startedConnection()
}
func (ct *connTracer) ClosedConnection(err error) {
ct.metricsCollector.closedConnection(err)
func (ct *connTracer) ClosedConnection() {
ct.metricsCollector.closedConnection()
}
func (ct *connTracer) ReceivedTransportParameters(params *logging.TransportParameters) {
ct.metricsCollector.receivedTransportParameters(params)
func (ct *connTracer) ReceivedTransportParameters(maxUDPPayloadSize int64, maxIdleTimeout time.Duration, maxDatagramFrameSize int64) {
ct.metricsCollector.receivedTransportParameters(maxUDPPayloadSize, maxIdleTimeout, maxDatagramFrameSize)
}
func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCount) {
func (ct *connTracer) SentPacket(size int64, frames []qlog.Frame) {
ct.metricsCollector.sentPackets(size, frames)
}
func (ct *connTracer) ReceivedPacket(size int64, frames []qlog.Frame) {
ct.metricsCollector.receivedPackets(size, frames)
}
func (ct *connTracer) BufferedPacket(pt qlog.PacketType) {
ct.metricsCollector.bufferedPackets(pt)
}
func (ct *connTracer) DroppedPacket(pt logging.PacketType, number logging.PacketNumber, size logging.ByteCount, reason logging.PacketDropReason) {
func (ct *connTracer) DroppedPacket(pt qlog.PacketType, size int64, reason qlog.PacketDropReason) {
ct.metricsCollector.droppedPackets(pt, size, reason)
}
func (ct *connTracer) LostPacket(level logging.EncryptionLevel, number logging.PacketNumber, reason logging.PacketLossReason) {
func (ct *connTracer) LostPacket(reason qlog.PacketLossReason) {
ct.metricsCollector.lostPackets(reason)
}
func (ct *connTracer) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) {
ct.metricsCollector.updatedRTT(rttStats)
ct.metricsCollector.updateCongestionWindow(cwnd)
func (ct *connTracer) UpdatedMetrics(m qlog.MetricsUpdated) {
ct.metricsCollector.updatedRTT(m)
ct.metricsCollector.updateCongestionWindow(int64(m.CongestionWindow))
}
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(size, frames)
}
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(size, frames)
}
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size, frames)
}
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size, frames)
}
func (ct *connTracer) UpdatedMTU(mtu logging.ByteCount, done bool) {
func (ct *connTracer) UpdatedMTU(mtu int64) {
ct.metricsCollector.updateMTU(mtu)
}
func (ct *connTracer) UpdatedCongestionState(state logging.CongestionState) {
func (ct *connTracer) UpdatedCongestionState(state qlog.CongestionState) {
ct.metricsCollector.updatedCongestionState(state)
}
+36
View File
@@ -0,0 +1,36 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended",
"schedule:nonOfficeHours"
],
"enabledManagers": [
"dockerfile"
],
"dockerfile": {
"managerFilePatterns": [
"/(^|/)Dockerfile\\.amd64$/",
"/(^|/)Dockerfile\\.arm64$/"
]
},
"packageRules": [
{
"description": "Disable updates for everything by default; only the distroless base image is managed for now",
"matchPackageNames": [
"*"
],
"enabled": false
},
{
"description": "Keep any distroless base image up to date by pinning and updating its digest, since tags like :nonroot are rolling tags without a semver version",
"matchManagers": [
"dockerfile"
],
"matchPackageNames": [
"gcr.io/distroless/**"
],
"enabled": true,
"pinDigests": true
}
]
}
+4 -5
View File
@@ -67,7 +67,6 @@ func (s *bidirectionalStreamStatus) wait(maxWaitForSecondStream time.Duration) e
// Only wait for second stream to finish if maxWait is greater than zero
if maxWaitForSecondStream > 0 {
timer := time.NewTimer(maxWaitForSecondStream)
defer timer.Stop()
@@ -87,14 +86,14 @@ func (s *bidirectionalStreamStatus) isAnyDone() bool {
// Pipe copies copy data to & from provided io.ReadWriters.
func Pipe(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log)
_ = PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log)
}
// PipeBidirectional copies data two BidirectionStreams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently.
// PipeBidirectional copies data to two unidirectional streams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently.
// The main difference is that when piping data from a reader to a writer, if EOF is read, then this implementation propagates the EOF signal to the destination/writer by closing the write side of the
// Bidirectional Stream.
// Finally, depending on once EOF is ready from one of the provided streams, the other direction of streaming data will have a configured time period to also finish, otherwise,
// the method will return immediately with a timeout error. It is however, the responsability of the caller to close the associated streams in both ends in order to free all the resources/go-routines.
// the method will return immediately with a timeout error. It is however, the responsibility of the caller to close the associated streams in both ends in order to free all the resources/go-routines.
func PipeBidirectional(downstream, upstream Stream, maxWaitForSecondStream time.Duration, log *zerolog.Logger) error {
status := newBiStreamStatus()
@@ -129,7 +128,7 @@ func unidirectionalStream(dst WriterCloser, src Reader, dir string, status *bidi
}
}()
defer dst.CloseWrite()
defer func() { _ = dst.CloseWrite() }()
_, err := copyData(dst, src, dir)
if err != nil {
-3
View File
@@ -66,9 +66,6 @@ type TunnelConfig struct {
// NoPrechecks disables connectivity pre-checks at startup.
NoPrechecks bool
// Prechecks enables connectivity pre-checks at startup.
Prechecks bool
NamedTunnel *connection.TunnelProperties
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
+2 -2
View File
@@ -407,7 +407,7 @@ func GetAppInfo(reqURL *url.URL) (*AppInfo, error) {
func handleRedirects(req *http.Request, via []*http.Request, orgToken string) error {
// attach org token to login request
if strings.Contains(req.URL.Path, AccessLoginWorkerPath) {
req.AddCookie(&http.Cookie{Name: tokenCookie, Value: orgToken})
req.AddCookie(&http.Cookie{Name: tokenCookie, Value: orgToken}) //nolint: gosec
}
// attach app session cookie to authorized request
@@ -417,7 +417,7 @@ func handleRedirects(req *http.Request, via []*http.Request, orgToken string) er
if prevReq != nil && prevReq.Response != nil {
for _, c := range prevReq.Response.Cookies() {
if c.Name == appSessionCookie {
req.AddCookie(&http.Cookie{Name: appSessionCookie, Value: c.Value})
req.AddCookie(&http.Cookie{Name: appSessionCookie, Value: c.Value}) //nolint: gosec
return nil
}
}
+9 -1
View File
@@ -414,6 +414,9 @@ func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) {
// Decrypt the given payload and return the content encryption key.
func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
if recipient == nil {
return nil, errors.New("go-jose/go-jose: missing recipient")
}
epk, err := headers.getEPK()
if err != nil {
return nil, errors.New("go-jose/go-jose: invalid epk header")
@@ -461,13 +464,18 @@ func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientI
return nil, ErrUnsupportedAlgorithm
}
encryptedKey := recipient.encryptedKey
if len(encryptedKey) == 0 {
return nil, errors.New("go-jose/go-jose: missing JWE Encrypted Key")
}
key := deriveKey(string(algorithm), keySize)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return josecipher.KeyUnwrap(block, recipient.encryptedKey)
return josecipher.KeyUnwrap(block, encryptedKey)
}
func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
+9 -1
View File
@@ -66,12 +66,20 @@ func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) {
}
// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher.
//
// https://datatracker.ietf.org/doc/html/rfc7518#section-4.4
// https://datatracker.ietf.org/doc/html/rfc7518#section-4.6
// https://datatracker.ietf.org/doc/html/rfc7518#section-4.8
func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) {
n := (len(ciphertext) / 8) - 1
if n <= 0 {
return nil, errors.New("go-jose/go-jose: JWE Encrypted Key too short")
}
if len(ciphertext)%8 != 0 {
return nil, errors.New("go-jose/go-jose: key wrap input must be 8 byte blocks")
}
n := (len(ciphertext) / 8) - 1
r := make([][]byte, n)
for i := range r {
+18 -8
View File
@@ -366,11 +366,21 @@ func (ctx *symmetricKeyCipher) encryptKey(cek []byte, alg KeyAlgorithm) (recipie
// Decrypt the content encryption key.
func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
switch headers.getAlgorithm() {
case DIRECT:
cek := make([]byte, len(ctx.key))
copy(cek, ctx.key)
return cek, nil
if recipient == nil {
return nil, fmt.Errorf("go-jose/go-jose: missing recipient")
}
alg := headers.getAlgorithm()
if alg == DIRECT {
return bytes.Clone(ctx.key), nil
}
encryptedKey := recipient.encryptedKey
if len(encryptedKey) == 0 {
return nil, fmt.Errorf("go-jose/go-jose: missing JWE Encrypted Key")
}
switch alg {
case A128GCMKW, A192GCMKW, A256GCMKW:
aead := newAESGCM(len(ctx.key))
@@ -385,7 +395,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
parts := &aeadParts{
iv: iv.bytes(),
ciphertext: recipient.encryptedKey,
ciphertext: encryptedKey,
tag: tag.bytes(),
}
@@ -401,7 +411,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
return nil, err
}
cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
cek, err := josecipher.KeyUnwrap(block, encryptedKey)
if err != nil {
return nil, err
}
@@ -445,7 +455,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
return nil, err
}
cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
cek, err := josecipher.KeyUnwrap(block, encryptedKey)
if err != nil {
return nil, err
}
-14
View File
@@ -1,14 +0,0 @@
# editorconfig.org
root = true
[*]
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
indent_style = tab
indent_size = 8
[*.{md,yml,yaml,json}]
indent_style = space
indent_size = 2
-1
View File
@@ -1 +0,0 @@
* text=auto
-2
View File
@@ -1,2 +0,0 @@
vendor/
/.glide
-383
View File
@@ -1,383 +0,0 @@
# Changelog
## Release 3.2.3 (2022-11-29)
### Changed
- Updated docs (thanks @book987 @aJetHorn @neelayu @pellizzetti @apricote @SaigyoujiYuyuko233 @AlekSi)
- #348: Updated huandu/xstrings which fixed a snake case bug (thanks @yxxhero)
- #353: Updated masterminds/semver which included bug fixes
- #354: Updated golang.org/x/crypto which included bug fixes
## Release 3.2.2 (2021-02-04)
This is a re-release of 3.2.1 to satisfy something with the Go module system.
## Release 3.2.1 (2021-02-04)
### Changed
- Upgraded `Masterminds/goutils` to `v1.1.1`. see the [Security Advisory](https://github.com/Masterminds/goutils/security/advisories/GHSA-xg2h-wx96-xgxr)
## Release 3.2.0 (2020-12-14)
### Added
- #211: Added randInt function (thanks @kochurovro)
- #223: Added fromJson and mustFromJson functions (thanks @mholt)
- #242: Added a bcrypt function (thanks @robbiet480)
- #253: Added randBytes function (thanks @MikaelSmith)
- #254: Added dig function for dicts (thanks @nyarly)
- #257: Added regexQuoteMeta for quoting regex metadata (thanks @rheaton)
- #261: Added filepath functions osBase, osDir, osExt, osClean, osIsAbs (thanks @zugl)
- #268: Added and and all functions for testing conditions (thanks @phuslu)
- #181: Added float64 arithmetic addf, add1f, subf, divf, mulf, maxf, and minf
(thanks @andrewmostello)
- #265: Added chunk function to split array into smaller arrays (thanks @karelbilek)
- #270: Extend certificate functions to handle non-RSA keys + add support for
ed25519 keys (thanks @misberner)
### Changed
- Removed testing and support for Go 1.12. ed25519 support requires Go 1.13 or newer
- Using semver 3.1.1 and mergo 0.3.11
### Fixed
- #249: Fix htmlDateInZone example (thanks @spawnia)
NOTE: The dependency github.com/imdario/mergo reverted the breaking change in
0.3.9 via 0.3.10 release.
## Release 3.1.0 (2020-04-16)
NOTE: The dependency github.com/imdario/mergo made a behavior change in 0.3.9
that impacts sprig functionality. Do not use sprig with a version newer than 0.3.8.
### Added
- #225: Added support for generating htpasswd hash (thanks @rustycl0ck)
- #224: Added duration filter (thanks @frebib)
- #205: Added `seq` function (thanks @thadc23)
### Changed
- #203: Unlambda functions with correct signature (thanks @muesli)
- #236: Updated the license formatting for GitHub display purposes
- #238: Updated package dependency versions. Note, mergo not updated to 0.3.9
as it causes a breaking change for sprig. That issue is tracked at
https://github.com/imdario/mergo/issues/139
### Fixed
- #229: Fix `seq` example in docs (thanks @kalmant)
## Release 3.0.2 (2019-12-13)
### Fixed
- #220: Updating to semver v3.0.3 to fix issue with <= ranges
- #218: fix typo elyptical->elliptic in ecdsa key description (thanks @laverya)
## Release 3.0.1 (2019-12-08)
### Fixed
- #212: Updated semver fixing broken constraint checking with ^0.0
## Release 3.0.0 (2019-10-02)
### Added
- #187: Added durationRound function (thanks @yjp20)
- #189: Added numerous template functions that return errors rather than panic (thanks @nrvnrvn)
- #193: Added toRawJson support (thanks @Dean-Coakley)
- #197: Added get support to dicts (thanks @Dean-Coakley)
### Changed
- #186: Moving dependency management to Go modules
- #186: Updated semver to v3. This has changes in the way ^ is handled
- #194: Updated documentation on merging and how it copies. Added example using deepCopy
- #196: trunc now supports negative values (thanks @Dean-Coakley)
## Release 2.22.0 (2019-10-02)
### Added
- #173: Added getHostByName function to resolve dns names to ips (thanks @fcgravalos)
- #195: Added deepCopy function for use with dicts
### Changed
- Updated merge and mergeOverwrite documentation to explain copying and how to
use deepCopy with it
## Release 2.21.0 (2019-09-18)
### Added
- #122: Added encryptAES/decryptAES functions (thanks @n0madic)
- #128: Added toDecimal support (thanks @Dean-Coakley)
- #169: Added list contcat (thanks @astorath)
- #174: Added deepEqual function (thanks @bonifaido)
- #170: Added url parse and join functions (thanks @astorath)
### Changed
- #171: Updated glide config for Google UUID to v1 and to add ranges to semver and testify
### Fixed
- #172: Fix semver wildcard example (thanks @piepmatz)
- #175: Fix dateInZone doc example (thanks @s3than)
## Release 2.20.0 (2019-06-18)
### Added
- #164: Adding function to get unix epoch for a time (@mattfarina)
- #166: Adding tests for date_in_zone (@mattfarina)
### Changed
- #144: Fix function comments based on best practices from Effective Go (@CodeLingoTeam)
- #150: Handles pointer type for time.Time in "htmlDate" (@mapreal19)
- #161, #157, #160, #153, #158, #156, #155, #159, #152 documentation updates (@badeadan)
### Fixed
## Release 2.19.0 (2019-03-02)
IMPORTANT: This release reverts a change from 2.18.0
In the previous release (2.18), we prematurely merged a partial change to the crypto functions that led to creating two sets of crypto functions (I blame @technosophos -- since that's me). This release rolls back that change, and does what was originally intended: It alters the existing crypto functions to use secure random.
We debated whether this classifies as a change worthy of major revision, but given the proximity to the last release, we have decided that treating 2.18 as a faulty release is the correct course of action. We apologize for any inconvenience.
### Changed
- Fix substr panic 35fb796 (Alexey igrychev)
- Remove extra period 1eb7729 (Matthew Lorimor)
- Make random string functions use crypto by default 6ceff26 (Matthew Lorimor)
- README edits/fixes/suggestions 08fe136 (Lauri Apple)
## Release 2.18.0 (2019-02-12)
### Added
- Added mergeOverwrite function
- cryptographic functions that use secure random (see fe1de12)
### Changed
- Improve documentation of regexMatch function, resolves #139 90b89ce (Jan Tagscherer)
- Handle has for nil list 9c10885 (Daniel Cohen)
- Document behaviour of mergeOverwrite fe0dbe9 (Lukas Rieder)
- doc: adds missing documentation. 4b871e6 (Fernandez Ludovic)
- Replace outdated goutils imports 01893d2 (Matthew Lorimor)
- Surface crypto secure random strings from goutils fe1de12 (Matthew Lorimor)
- Handle untyped nil values as paramters to string functions 2b2ec8f (Morten Torkildsen)
### Fixed
- Fix dict merge issue and provide mergeOverwrite .dst .src1 to overwrite from src -> dst 4c59c12 (Lukas Rieder)
- Fix substr var names and comments d581f80 (Dean Coakley)
- Fix substr documentation 2737203 (Dean Coakley)
## Release 2.17.1 (2019-01-03)
### Fixed
The 2.17.0 release did not have a version pinned for xstrings, which caused compilation failures when xstrings < 1.2 was used. This adds the correct version string to glide.yaml.
## Release 2.17.0 (2019-01-03)
### Added
- adds alder32sum function and test 6908fc2 (marshallford)
- Added kebabcase function ca331a1 (Ilyes512)
### Changed
- Update goutils to 1.1.0 4e1125d (Matt Butcher)
### Fixed
- Fix 'has' documentation e3f2a85 (dean-coakley)
- docs(dict): fix typo in pick example dc424f9 (Dustin Specker)
- fixes spelling errors... not sure how that happened 4cf188a (marshallford)
## Release 2.16.0 (2018-08-13)
### Added
- add splitn function fccb0b0 (Helgi Þorbjörnsson)
- Add slice func df28ca7 (gongdo)
- Generate serial number a3bdffd (Cody Coons)
- Extract values of dict with values function df39312 (Lawrence Jones)
### Changed
- Modify panic message for list.slice ae38335 (gongdo)
- Minor improvement in code quality - Removed an unreachable piece of code at defaults.go#L26:6 - Resolve formatting issues. 5834241 (Abhishek Kashyap)
- Remove duplicated documentation 1d97af1 (Matthew Fisher)
- Test on go 1.11 49df809 (Helgi Þormar Þorbjörnsson)
### Fixed
- Fix file permissions c5f40b5 (gongdo)
- Fix example for buildCustomCert 7779e0d (Tin Lam)
## Release 2.15.0 (2018-04-02)
### Added
- #68 and #69: Add json helpers to docs (thanks @arunvelsriram)
- #66: Add ternary function (thanks @binoculars)
- #67: Allow keys function to take multiple dicts (thanks @binoculars)
- #89: Added sha1sum to crypto function (thanks @benkeil)
- #81: Allow customizing Root CA that used by genSignedCert (thanks @chenzhiwei)
- #92: Add travis testing for go 1.10
- #93: Adding appveyor config for windows testing
### Changed
- #90: Updating to more recent dependencies
- #73: replace satori/go.uuid with google/uuid (thanks @petterw)
### Fixed
- #76: Fixed documentation typos (thanks @Thiht)
- Fixed rounding issue on the `ago` function. Note, the removes support for Go 1.8 and older
## Release 2.14.1 (2017-12-01)
### Fixed
- #60: Fix typo in function name documentation (thanks @neil-ca-moore)
- #61: Removing line with {{ due to blocking github pages genertion
- #64: Update the list functions to handle int, string, and other slices for compatibility
## Release 2.14.0 (2017-10-06)
This new version of Sprig adds a set of functions for generating and working with SSL certificates.
- `genCA` generates an SSL Certificate Authority
- `genSelfSignedCert` generates an SSL self-signed certificate
- `genSignedCert` generates an SSL certificate and key based on a given CA
## Release 2.13.0 (2017-09-18)
This release adds new functions, including:
- `regexMatch`, `regexFindAll`, `regexFind`, `regexReplaceAll`, `regexReplaceAllLiteral`, and `regexSplit` to work with regular expressions
- `floor`, `ceil`, and `round` math functions
- `toDate` converts a string to a date
- `nindent` is just like `indent` but also prepends a new line
- `ago` returns the time from `time.Now`
### Added
- #40: Added basic regex functionality (thanks @alanquillin)
- #41: Added ceil floor and round functions (thanks @alanquillin)
- #48: Added toDate function (thanks @andreynering)
- #50: Added nindent function (thanks @binoculars)
- #46: Added ago function (thanks @slayer)
### Changed
- #51: Updated godocs to include new string functions (thanks @curtisallen)
- #49: Added ability to merge multiple dicts (thanks @binoculars)
## Release 2.12.0 (2017-05-17)
- `snakecase`, `camelcase`, and `shuffle` are three new string functions
- `fail` allows you to bail out of a template render when conditions are not met
## Release 2.11.0 (2017-05-02)
- Added `toJson` and `toPrettyJson`
- Added `merge`
- Refactored documentation
## Release 2.10.0 (2017-03-15)
- Added `semver` and `semverCompare` for Semantic Versions
- `list` replaces `tuple`
- Fixed issue with `join`
- Added `first`, `last`, `intial`, `rest`, `prepend`, `append`, `toString`, `toStrings`, `sortAlpha`, `reverse`, `coalesce`, `pluck`, `pick`, `compact`, `keys`, `omit`, `uniq`, `has`, `without`
## Release 2.9.0 (2017-02-23)
- Added `splitList` to split a list
- Added crypto functions of `genPrivateKey` and `derivePassword`
## Release 2.8.0 (2016-12-21)
- Added access to several path functions (`base`, `dir`, `clean`, `ext`, and `abs`)
- Added functions for _mutating_ dictionaries (`set`, `unset`, `hasKey`)
## Release 2.7.0 (2016-12-01)
- Added `sha256sum` to generate a hash of an input
- Added functions to convert a numeric or string to `int`, `int64`, `float64`
## Release 2.6.0 (2016-10-03)
- Added a `uuidv4` template function for generating UUIDs inside of a template.
## Release 2.5.0 (2016-08-19)
- New `trimSuffix`, `trimPrefix`, `hasSuffix`, and `hasPrefix` functions
- New aliases have been added for a few functions that didn't follow the naming conventions (`trimAll` and `abbrevBoth`)
- `trimall` and `abbrevboth` (notice the case) are deprecated and will be removed in 3.0.0
## Release 2.4.0 (2016-08-16)
- Adds two functions: `until` and `untilStep`
## Release 2.3.0 (2016-06-21)
- cat: Concatenate strings with whitespace separators.
- replace: Replace parts of a string: `replace " " "-" "Me First"` renders "Me-First"
- plural: Format plurals: `len "foo" | plural "one foo" "many foos"` renders "many foos"
- indent: Indent blocks of text in a way that is sensitive to "\n" characters.
## Release 2.2.0 (2016-04-21)
- Added a `genPrivateKey` function (Thanks @bacongobbler)
## Release 2.1.0 (2016-03-30)
- `default` now prints the default value when it does not receive a value down the pipeline. It is much safer now to do `{{.Foo | default "bar"}}`.
- Added accessors for "hermetic" functions. These return only functions that, when given the same input, produce the same output.
## Release 2.0.0 (2016-03-29)
Because we switched from `int` to `int64` as the return value for all integer math functions, the library's major version number has been incremented.
- `min` complements `max` (formerly `biggest`)
- `empty` indicates that a value is the empty value for its type
- `tuple` creates a tuple inside of a template: `{{$t := tuple "a", "b" "c"}}`
- `dict` creates a dictionary inside of a template `{{$d := dict "key1" "val1" "key2" "val2"}}`
- Date formatters have been added for HTML dates (as used in `date` input fields)
- Integer math functions can convert from a number of types, including `string` (via `strconv.ParseInt`).
## Release 1.2.0 (2016-02-01)
- Added quote and squote
- Added b32enc and b32dec
- add now takes varargs
- biggest now takes varargs
## Release 1.1.0 (2015-12-29)
- Added #4: Added contains function. strings.Contains, but with the arguments
switched to simplify common pipelines. (thanks krancour)
- Added Travis-CI testing support
## Release 1.0.0 (2015-12-23)
- Initial release
-19
View File
@@ -1,19 +0,0 @@
Copyright (C) 2013-2020 Masterminds
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
-73
View File
@@ -1,73 +0,0 @@
# Slim-Sprig: Template functions for Go templates [![Go Reference](https://pkg.go.dev/badge/github.com/go-task/slim-sprig/v3.svg)](https://pkg.go.dev/github.com/go-task/slim-sprig/v3)
Slim-Sprig is a fork of [Sprig](https://github.com/Masterminds/sprig), but with
all functions that depend on external (non standard library) or crypto packages
removed.
The reason for this is to make this library more lightweight. Most of these
functions (specially crypto ones) are not needed on most apps, but costs a lot
in terms of binary size and compilation time.
## Usage
**Template developers**: Please use Slim-Sprig's [function documentation](https://go-task.github.io/slim-sprig/) for
detailed instructions and code snippets for the >100 template functions available.
**Go developers**: If you'd like to include Slim-Sprig as a library in your program,
our API documentation is available [at GoDoc.org](http://godoc.org/github.com/go-task/slim-sprig).
For standard usage, read on.
### Load the Slim-Sprig library
To load the Slim-Sprig `FuncMap`:
```go
import (
"html/template"
"github.com/go-task/slim-sprig"
)
// This example illustrates that the FuncMap *must* be set before the
// templates themselves are loaded.
tpl := template.Must(
template.New("base").Funcs(sprig.FuncMap()).ParseGlob("*.html")
)
```
### Calling the functions inside of templates
By convention, all functions are lowercase. This seems to follow the Go
idiom for template functions (as opposed to template methods, which are
TitleCase). For example, this:
```
{{ "hello!" | upper | repeat 5 }}
```
produces this:
```
HELLO!HELLO!HELLO!HELLO!HELLO!
```
## Principles Driving Our Function Selection
We followed these principles to decide which functions to add and how to implement them:
- Use template functions to build layout. The following
types of operations are within the domain of template functions:
- Formatting
- Layout
- Simple type conversions
- Utilities that assist in handling common formatting and layout needs (e.g. arithmetic)
- Template functions should not return errors unless there is no way to print
a sensible value. For example, converting a string to an integer should not
produce an error if conversion fails. Instead, it should display a default
value.
- Simple math is necessary for grid layouts, pagers, and so on. Complex math
(anything other than arithmetic) should be done outside of templates.
- Template functions only deal with the data passed into them. They never retrieve
data from a source.
- Finally, do not override core Go template functions.
-12
View File
@@ -1,12 +0,0 @@
# https://taskfile.dev
version: '3'
tasks:
default:
cmds:
- task: test
test:
cmds:
- go test -v .
-24
View File
@@ -1,24 +0,0 @@
package sprig
import (
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash/adler32"
)
func sha256sum(input string) string {
hash := sha256.Sum256([]byte(input))
return hex.EncodeToString(hash[:])
}
func sha1sum(input string) string {
hash := sha1.Sum([]byte(input))
return hex.EncodeToString(hash[:])
}
func adler32sum(input string) string {
hash := adler32.Checksum([]byte(input))
return fmt.Sprintf("%d", hash)
}
-152
View File
@@ -1,152 +0,0 @@
package sprig
import (
"strconv"
"time"
)
// Given a format and a date, format the date string.
//
// Date can be a `time.Time` or an `int, int32, int64`.
// In the later case, it is treated as seconds since UNIX
// epoch.
func date(fmt string, date interface{}) string {
return dateInZone(fmt, date, "Local")
}
func htmlDate(date interface{}) string {
return dateInZone("2006-01-02", date, "Local")
}
func htmlDateInZone(date interface{}, zone string) string {
return dateInZone("2006-01-02", date, zone)
}
func dateInZone(fmt string, date interface{}, zone string) string {
var t time.Time
switch date := date.(type) {
default:
t = time.Now()
case time.Time:
t = date
case *time.Time:
t = *date
case int64:
t = time.Unix(date, 0)
case int:
t = time.Unix(int64(date), 0)
case int32:
t = time.Unix(int64(date), 0)
}
loc, err := time.LoadLocation(zone)
if err != nil {
loc, _ = time.LoadLocation("UTC")
}
return t.In(loc).Format(fmt)
}
func dateModify(fmt string, date time.Time) time.Time {
d, err := time.ParseDuration(fmt)
if err != nil {
return date
}
return date.Add(d)
}
func mustDateModify(fmt string, date time.Time) (time.Time, error) {
d, err := time.ParseDuration(fmt)
if err != nil {
return time.Time{}, err
}
return date.Add(d), nil
}
func dateAgo(date interface{}) string {
var t time.Time
switch date := date.(type) {
default:
t = time.Now()
case time.Time:
t = date
case int64:
t = time.Unix(date, 0)
case int:
t = time.Unix(int64(date), 0)
}
// Drop resolution to seconds
duration := time.Since(t).Round(time.Second)
return duration.String()
}
func duration(sec interface{}) string {
var n int64
switch value := sec.(type) {
default:
n = 0
case string:
n, _ = strconv.ParseInt(value, 10, 64)
case int64:
n = value
}
return (time.Duration(n) * time.Second).String()
}
func durationRound(duration interface{}) string {
var d time.Duration
switch duration := duration.(type) {
default:
d = 0
case string:
d, _ = time.ParseDuration(duration)
case int64:
d = time.Duration(duration)
case time.Time:
d = time.Since(duration)
}
u := uint64(d)
neg := d < 0
if neg {
u = -u
}
var (
year = uint64(time.Hour) * 24 * 365
month = uint64(time.Hour) * 24 * 30
day = uint64(time.Hour) * 24
hour = uint64(time.Hour)
minute = uint64(time.Minute)
second = uint64(time.Second)
)
switch {
case u > year:
return strconv.FormatUint(u/year, 10) + "y"
case u > month:
return strconv.FormatUint(u/month, 10) + "mo"
case u > day:
return strconv.FormatUint(u/day, 10) + "d"
case u > hour:
return strconv.FormatUint(u/hour, 10) + "h"
case u > minute:
return strconv.FormatUint(u/minute, 10) + "m"
case u > second:
return strconv.FormatUint(u/second, 10) + "s"
}
return "0s"
}
func toDate(fmt, str string) time.Time {
t, _ := time.ParseInLocation(fmt, str, time.Local)
return t
}
func mustToDate(fmt, str string) (time.Time, error) {
return time.ParseInLocation(fmt, str, time.Local)
}
func unixEpoch(date time.Time) string {
return strconv.FormatInt(date.Unix(), 10)
}
-163
View File
@@ -1,163 +0,0 @@
package sprig
import (
"bytes"
"encoding/json"
"math/rand"
"reflect"
"strings"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// dfault checks whether `given` is set, and returns default if not set.
//
// This returns `d` if `given` appears not to be set, and `given` otherwise.
//
// For numeric types 0 is unset.
// For strings, maps, arrays, and slices, len() = 0 is considered unset.
// For bool, false is unset.
// Structs are never considered unset.
//
// For everything else, including pointers, a nil value is unset.
func dfault(d interface{}, given ...interface{}) interface{} {
if empty(given) || empty(given[0]) {
return d
}
return given[0]
}
// empty returns true if the given value has the zero value for its type.
func empty(given interface{}) bool {
g := reflect.ValueOf(given)
if !g.IsValid() {
return true
}
// Basically adapted from text/template.isTrue
switch g.Kind() {
default:
return g.IsNil()
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
return g.Len() == 0
case reflect.Bool:
return !g.Bool()
case reflect.Complex64, reflect.Complex128:
return g.Complex() == 0
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return g.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return g.Uint() == 0
case reflect.Float32, reflect.Float64:
return g.Float() == 0
case reflect.Struct:
return false
}
}
// coalesce returns the first non-empty value.
func coalesce(v ...interface{}) interface{} {
for _, val := range v {
if !empty(val) {
return val
}
}
return nil
}
// all returns true if empty(x) is false for all values x in the list.
// If the list is empty, return true.
func all(v ...interface{}) bool {
for _, val := range v {
if empty(val) {
return false
}
}
return true
}
// any returns true if empty(x) is false for any x in the list.
// If the list is empty, return false.
func any(v ...interface{}) bool {
for _, val := range v {
if !empty(val) {
return true
}
}
return false
}
// fromJson decodes JSON into a structured value, ignoring errors.
func fromJson(v string) interface{} {
output, _ := mustFromJson(v)
return output
}
// mustFromJson decodes JSON into a structured value, returning errors.
func mustFromJson(v string) (interface{}, error) {
var output interface{}
err := json.Unmarshal([]byte(v), &output)
return output, err
}
// toJson encodes an item into a JSON string
func toJson(v interface{}) string {
output, _ := json.Marshal(v)
return string(output)
}
func mustToJson(v interface{}) (string, error) {
output, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(output), nil
}
// toPrettyJson encodes an item into a pretty (indented) JSON string
func toPrettyJson(v interface{}) string {
output, _ := json.MarshalIndent(v, "", " ")
return string(output)
}
func mustToPrettyJson(v interface{}) (string, error) {
output, err := json.MarshalIndent(v, "", " ")
if err != nil {
return "", err
}
return string(output), nil
}
// toRawJson encodes an item into a JSON string with no escaping of HTML characters.
func toRawJson(v interface{}) string {
output, err := mustToRawJson(v)
if err != nil {
panic(err)
}
return string(output)
}
// mustToRawJson encodes an item into a JSON string with no escaping of HTML characters.
func mustToRawJson(v interface{}) (string, error) {
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err := enc.Encode(&v)
if err != nil {
return "", err
}
return strings.TrimSuffix(buf.String(), "\n"), nil
}
// ternary returns the first value if the last value is true, otherwise returns the second value.
func ternary(vt interface{}, vf interface{}, v bool) interface{} {
if v {
return vt
}
return vf
}
-118
View File
@@ -1,118 +0,0 @@
package sprig
func get(d map[string]interface{}, key string) interface{} {
if val, ok := d[key]; ok {
return val
}
return ""
}
func set(d map[string]interface{}, key string, value interface{}) map[string]interface{} {
d[key] = value
return d
}
func unset(d map[string]interface{}, key string) map[string]interface{} {
delete(d, key)
return d
}
func hasKey(d map[string]interface{}, key string) bool {
_, ok := d[key]
return ok
}
func pluck(key string, d ...map[string]interface{}) []interface{} {
res := []interface{}{}
for _, dict := range d {
if val, ok := dict[key]; ok {
res = append(res, val)
}
}
return res
}
func keys(dicts ...map[string]interface{}) []string {
k := []string{}
for _, dict := range dicts {
for key := range dict {
k = append(k, key)
}
}
return k
}
func pick(dict map[string]interface{}, keys ...string) map[string]interface{} {
res := map[string]interface{}{}
for _, k := range keys {
if v, ok := dict[k]; ok {
res[k] = v
}
}
return res
}
func omit(dict map[string]interface{}, keys ...string) map[string]interface{} {
res := map[string]interface{}{}
omit := make(map[string]bool, len(keys))
for _, k := range keys {
omit[k] = true
}
for k, v := range dict {
if _, ok := omit[k]; !ok {
res[k] = v
}
}
return res
}
func dict(v ...interface{}) map[string]interface{} {
dict := map[string]interface{}{}
lenv := len(v)
for i := 0; i < lenv; i += 2 {
key := strval(v[i])
if i+1 >= lenv {
dict[key] = ""
continue
}
dict[key] = v[i+1]
}
return dict
}
func values(dict map[string]interface{}) []interface{} {
values := []interface{}{}
for _, value := range dict {
values = append(values, value)
}
return values
}
func dig(ps ...interface{}) (interface{}, error) {
if len(ps) < 3 {
panic("dig needs at least three arguments")
}
dict := ps[len(ps)-1].(map[string]interface{})
def := ps[len(ps)-2]
ks := make([]string, len(ps)-2)
for i := 0; i < len(ks); i++ {
ks[i] = ps[i].(string)
}
return digFromDict(dict, def, ks)
}
func digFromDict(dict map[string]interface{}, d interface{}, ks []string) (interface{}, error) {
k, ns := ks[0], ks[1:len(ks)]
step, has := dict[k]
if !has {
return d, nil
}
if len(ns) == 0 {
return step, nil
}
return digFromDict(step.(map[string]interface{}), d, ns)
}
-19
View File
@@ -1,19 +0,0 @@
/*
Package sprig provides template functions for Go.
This package contains a number of utility functions for working with data
inside of Go `html/template` and `text/template` files.
To add these functions, use the `template.Funcs()` method:
t := templates.New("foo").Funcs(sprig.FuncMap())
Note that you should add the function map before you parse any template files.
In several cases, Sprig reverses the order of arguments from the way they
appear in the standard library. This is to make it easier to pipe
arguments into functions.
See http://masterminds.github.io/sprig/ for more detailed documentation on each of the available functions.
*/
package sprig
-317
View File
@@ -1,317 +0,0 @@
package sprig
import (
"errors"
"html/template"
"math/rand"
"os"
"path"
"path/filepath"
"reflect"
"strconv"
"strings"
ttemplate "text/template"
"time"
)
// FuncMap produces the function map.
//
// Use this to pass the functions into the template engine:
//
// tpl := template.New("foo").Funcs(sprig.FuncMap()))
//
func FuncMap() template.FuncMap {
return HtmlFuncMap()
}
// HermeticTxtFuncMap returns a 'text/template'.FuncMap with only repeatable functions.
func HermeticTxtFuncMap() ttemplate.FuncMap {
r := TxtFuncMap()
for _, name := range nonhermeticFunctions {
delete(r, name)
}
return r
}
// HermeticHtmlFuncMap returns an 'html/template'.Funcmap with only repeatable functions.
func HermeticHtmlFuncMap() template.FuncMap {
r := HtmlFuncMap()
for _, name := range nonhermeticFunctions {
delete(r, name)
}
return r
}
// TxtFuncMap returns a 'text/template'.FuncMap
func TxtFuncMap() ttemplate.FuncMap {
return ttemplate.FuncMap(GenericFuncMap())
}
// HtmlFuncMap returns an 'html/template'.Funcmap
func HtmlFuncMap() template.FuncMap {
return template.FuncMap(GenericFuncMap())
}
// GenericFuncMap returns a copy of the basic function map as a map[string]interface{}.
func GenericFuncMap() map[string]interface{} {
gfm := make(map[string]interface{}, len(genericMap))
for k, v := range genericMap {
gfm[k] = v
}
return gfm
}
// These functions are not guaranteed to evaluate to the same result for given input, because they
// refer to the environment or global state.
var nonhermeticFunctions = []string{
// Date functions
"date",
"date_in_zone",
"date_modify",
"now",
"htmlDate",
"htmlDateInZone",
"dateInZone",
"dateModify",
// Strings
"randAlphaNum",
"randAlpha",
"randAscii",
"randNumeric",
"randBytes",
"uuidv4",
// OS
"env",
"expandenv",
// Network
"getHostByName",
}
var genericMap = map[string]interface{}{
"hello": func() string { return "Hello!" },
// Date functions
"ago": dateAgo,
"date": date,
"date_in_zone": dateInZone,
"date_modify": dateModify,
"dateInZone": dateInZone,
"dateModify": dateModify,
"duration": duration,
"durationRound": durationRound,
"htmlDate": htmlDate,
"htmlDateInZone": htmlDateInZone,
"must_date_modify": mustDateModify,
"mustDateModify": mustDateModify,
"mustToDate": mustToDate,
"now": time.Now,
"toDate": toDate,
"unixEpoch": unixEpoch,
// Strings
"trunc": trunc,
"trim": strings.TrimSpace,
"upper": strings.ToUpper,
"lower": strings.ToLower,
"title": strings.Title,
"substr": substring,
// Switch order so that "foo" | repeat 5
"repeat": func(count int, str string) string { return strings.Repeat(str, count) },
// Deprecated: Use trimAll.
"trimall": func(a, b string) string { return strings.Trim(b, a) },
// Switch order so that "$foo" | trimall "$"
"trimAll": func(a, b string) string { return strings.Trim(b, a) },
"trimSuffix": func(a, b string) string { return strings.TrimSuffix(b, a) },
"trimPrefix": func(a, b string) string { return strings.TrimPrefix(b, a) },
// Switch order so that "foobar" | contains "foo"
"contains": func(substr string, str string) bool { return strings.Contains(str, substr) },
"hasPrefix": func(substr string, str string) bool { return strings.HasPrefix(str, substr) },
"hasSuffix": func(substr string, str string) bool { return strings.HasSuffix(str, substr) },
"quote": quote,
"squote": squote,
"cat": cat,
"indent": indent,
"nindent": nindent,
"replace": replace,
"plural": plural,
"sha1sum": sha1sum,
"sha256sum": sha256sum,
"adler32sum": adler32sum,
"toString": strval,
// Wrap Atoi to stop errors.
"atoi": func(a string) int { i, _ := strconv.Atoi(a); return i },
"int64": toInt64,
"int": toInt,
"float64": toFloat64,
"seq": seq,
"toDecimal": toDecimal,
//"gt": func(a, b int) bool {return a > b},
//"gte": func(a, b int) bool {return a >= b},
//"lt": func(a, b int) bool {return a < b},
//"lte": func(a, b int) bool {return a <= b},
// split "/" foo/bar returns map[int]string{0: foo, 1: bar}
"split": split,
"splitList": func(sep, orig string) []string { return strings.Split(orig, sep) },
// splitn "/" foo/bar/fuu returns map[int]string{0: foo, 1: bar/fuu}
"splitn": splitn,
"toStrings": strslice,
"until": until,
"untilStep": untilStep,
// VERY basic arithmetic.
"add1": func(i interface{}) int64 { return toInt64(i) + 1 },
"add": func(i ...interface{}) int64 {
var a int64 = 0
for _, b := range i {
a += toInt64(b)
}
return a
},
"sub": func(a, b interface{}) int64 { return toInt64(a) - toInt64(b) },
"div": func(a, b interface{}) int64 { return toInt64(a) / toInt64(b) },
"mod": func(a, b interface{}) int64 { return toInt64(a) % toInt64(b) },
"mul": func(a interface{}, v ...interface{}) int64 {
val := toInt64(a)
for _, b := range v {
val = val * toInt64(b)
}
return val
},
"randInt": func(min, max int) int { return rand.Intn(max-min) + min },
"biggest": max,
"max": max,
"min": min,
"maxf": maxf,
"minf": minf,
"ceil": ceil,
"floor": floor,
"round": round,
// string slices. Note that we reverse the order b/c that's better
// for template processing.
"join": join,
"sortAlpha": sortAlpha,
// Defaults
"default": dfault,
"empty": empty,
"coalesce": coalesce,
"all": all,
"any": any,
"compact": compact,
"mustCompact": mustCompact,
"fromJson": fromJson,
"toJson": toJson,
"toPrettyJson": toPrettyJson,
"toRawJson": toRawJson,
"mustFromJson": mustFromJson,
"mustToJson": mustToJson,
"mustToPrettyJson": mustToPrettyJson,
"mustToRawJson": mustToRawJson,
"ternary": ternary,
// Reflection
"typeOf": typeOf,
"typeIs": typeIs,
"typeIsLike": typeIsLike,
"kindOf": kindOf,
"kindIs": kindIs,
"deepEqual": reflect.DeepEqual,
// OS:
"env": os.Getenv,
"expandenv": os.ExpandEnv,
// Network:
"getHostByName": getHostByName,
// Paths:
"base": path.Base,
"dir": path.Dir,
"clean": path.Clean,
"ext": path.Ext,
"isAbs": path.IsAbs,
// Filepaths:
"osBase": filepath.Base,
"osClean": filepath.Clean,
"osDir": filepath.Dir,
"osExt": filepath.Ext,
"osIsAbs": filepath.IsAbs,
// Encoding:
"b64enc": base64encode,
"b64dec": base64decode,
"b32enc": base32encode,
"b32dec": base32decode,
// Data Structures:
"tuple": list, // FIXME: with the addition of append/prepend these are no longer immutable.
"list": list,
"dict": dict,
"get": get,
"set": set,
"unset": unset,
"hasKey": hasKey,
"pluck": pluck,
"keys": keys,
"pick": pick,
"omit": omit,
"values": values,
"append": push, "push": push,
"mustAppend": mustPush, "mustPush": mustPush,
"prepend": prepend,
"mustPrepend": mustPrepend,
"first": first,
"mustFirst": mustFirst,
"rest": rest,
"mustRest": mustRest,
"last": last,
"mustLast": mustLast,
"initial": initial,
"mustInitial": mustInitial,
"reverse": reverse,
"mustReverse": mustReverse,
"uniq": uniq,
"mustUniq": mustUniq,
"without": without,
"mustWithout": mustWithout,
"has": has,
"mustHas": mustHas,
"slice": slice,
"mustSlice": mustSlice,
"concat": concat,
"dig": dig,
"chunk": chunk,
"mustChunk": mustChunk,
// Flow Control:
"fail": func(msg string) (string, error) { return "", errors.New(msg) },
// Regex
"regexMatch": regexMatch,
"mustRegexMatch": mustRegexMatch,
"regexFindAll": regexFindAll,
"mustRegexFindAll": mustRegexFindAll,
"regexFind": regexFind,
"mustRegexFind": mustRegexFind,
"regexReplaceAll": regexReplaceAll,
"mustRegexReplaceAll": mustRegexReplaceAll,
"regexReplaceAllLiteral": regexReplaceAllLiteral,
"mustRegexReplaceAllLiteral": mustRegexReplaceAllLiteral,
"regexSplit": regexSplit,
"mustRegexSplit": mustRegexSplit,
"regexQuoteMeta": regexQuoteMeta,
// URLs:
"urlParse": urlParse,
"urlJoin": urlJoin,
}
-464
View File
@@ -1,464 +0,0 @@
package sprig
import (
"fmt"
"math"
"reflect"
"sort"
)
// Reflection is used in these functions so that slices and arrays of strings,
// ints, and other types not implementing []interface{} can be worked with.
// For example, this is useful if you need to work on the output of regexs.
func list(v ...interface{}) []interface{} {
return v
}
func push(list interface{}, v interface{}) []interface{} {
l, err := mustPush(list, v)
if err != nil {
panic(err)
}
return l
}
func mustPush(list interface{}, v interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
nl := make([]interface{}, l)
for i := 0; i < l; i++ {
nl[i] = l2.Index(i).Interface()
}
return append(nl, v), nil
default:
return nil, fmt.Errorf("Cannot push on type %s", tp)
}
}
func prepend(list interface{}, v interface{}) []interface{} {
l, err := mustPrepend(list, v)
if err != nil {
panic(err)
}
return l
}
func mustPrepend(list interface{}, v interface{}) ([]interface{}, error) {
//return append([]interface{}{v}, list...)
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
nl := make([]interface{}, l)
for i := 0; i < l; i++ {
nl[i] = l2.Index(i).Interface()
}
return append([]interface{}{v}, nl...), nil
default:
return nil, fmt.Errorf("Cannot prepend on type %s", tp)
}
}
func chunk(size int, list interface{}) [][]interface{} {
l, err := mustChunk(size, list)
if err != nil {
panic(err)
}
return l
}
func mustChunk(size int, list interface{}) ([][]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
cs := int(math.Floor(float64(l-1)/float64(size)) + 1)
nl := make([][]interface{}, cs)
for i := 0; i < cs; i++ {
clen := size
if i == cs-1 {
clen = int(math.Floor(math.Mod(float64(l), float64(size))))
if clen == 0 {
clen = size
}
}
nl[i] = make([]interface{}, clen)
for j := 0; j < clen; j++ {
ix := i*size + j
nl[i][j] = l2.Index(ix).Interface()
}
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot chunk type %s", tp)
}
}
func last(list interface{}) interface{} {
l, err := mustLast(list)
if err != nil {
panic(err)
}
return l
}
func mustLast(list interface{}) (interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
if l == 0 {
return nil, nil
}
return l2.Index(l - 1).Interface(), nil
default:
return nil, fmt.Errorf("Cannot find last on type %s", tp)
}
}
func first(list interface{}) interface{} {
l, err := mustFirst(list)
if err != nil {
panic(err)
}
return l
}
func mustFirst(list interface{}) (interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
if l == 0 {
return nil, nil
}
return l2.Index(0).Interface(), nil
default:
return nil, fmt.Errorf("Cannot find first on type %s", tp)
}
}
func rest(list interface{}) []interface{} {
l, err := mustRest(list)
if err != nil {
panic(err)
}
return l
}
func mustRest(list interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
if l == 0 {
return nil, nil
}
nl := make([]interface{}, l-1)
for i := 1; i < l; i++ {
nl[i-1] = l2.Index(i).Interface()
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot find rest on type %s", tp)
}
}
func initial(list interface{}) []interface{} {
l, err := mustInitial(list)
if err != nil {
panic(err)
}
return l
}
func mustInitial(list interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
if l == 0 {
return nil, nil
}
nl := make([]interface{}, l-1)
for i := 0; i < l-1; i++ {
nl[i] = l2.Index(i).Interface()
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot find initial on type %s", tp)
}
}
func sortAlpha(list interface{}) []string {
k := reflect.Indirect(reflect.ValueOf(list)).Kind()
switch k {
case reflect.Slice, reflect.Array:
a := strslice(list)
s := sort.StringSlice(a)
s.Sort()
return s
}
return []string{strval(list)}
}
func reverse(v interface{}) []interface{} {
l, err := mustReverse(v)
if err != nil {
panic(err)
}
return l
}
func mustReverse(v interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(v).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(v)
l := l2.Len()
// We do not sort in place because the incoming array should not be altered.
nl := make([]interface{}, l)
for i := 0; i < l; i++ {
nl[l-i-1] = l2.Index(i).Interface()
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot find reverse on type %s", tp)
}
}
func compact(list interface{}) []interface{} {
l, err := mustCompact(list)
if err != nil {
panic(err)
}
return l
}
func mustCompact(list interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
nl := []interface{}{}
var item interface{}
for i := 0; i < l; i++ {
item = l2.Index(i).Interface()
if !empty(item) {
nl = append(nl, item)
}
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot compact on type %s", tp)
}
}
func uniq(list interface{}) []interface{} {
l, err := mustUniq(list)
if err != nil {
panic(err)
}
return l
}
func mustUniq(list interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
dest := []interface{}{}
var item interface{}
for i := 0; i < l; i++ {
item = l2.Index(i).Interface()
if !inList(dest, item) {
dest = append(dest, item)
}
}
return dest, nil
default:
return nil, fmt.Errorf("Cannot find uniq on type %s", tp)
}
}
func inList(haystack []interface{}, needle interface{}) bool {
for _, h := range haystack {
if reflect.DeepEqual(needle, h) {
return true
}
}
return false
}
func without(list interface{}, omit ...interface{}) []interface{} {
l, err := mustWithout(list, omit...)
if err != nil {
panic(err)
}
return l
}
func mustWithout(list interface{}, omit ...interface{}) ([]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
res := []interface{}{}
var item interface{}
for i := 0; i < l; i++ {
item = l2.Index(i).Interface()
if !inList(omit, item) {
res = append(res, item)
}
}
return res, nil
default:
return nil, fmt.Errorf("Cannot find without on type %s", tp)
}
}
func has(needle interface{}, haystack interface{}) bool {
l, err := mustHas(needle, haystack)
if err != nil {
panic(err)
}
return l
}
func mustHas(needle interface{}, haystack interface{}) (bool, error) {
if haystack == nil {
return false, nil
}
tp := reflect.TypeOf(haystack).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(haystack)
var item interface{}
l := l2.Len()
for i := 0; i < l; i++ {
item = l2.Index(i).Interface()
if reflect.DeepEqual(needle, item) {
return true, nil
}
}
return false, nil
default:
return false, fmt.Errorf("Cannot find has on type %s", tp)
}
}
// $list := [1, 2, 3, 4, 5]
// slice $list -> list[0:5] = list[:]
// slice $list 0 3 -> list[0:3] = list[:3]
// slice $list 3 5 -> list[3:5]
// slice $list 3 -> list[3:5] = list[3:]
func slice(list interface{}, indices ...interface{}) interface{} {
l, err := mustSlice(list, indices...)
if err != nil {
panic(err)
}
return l
}
func mustSlice(list interface{}, indices ...interface{}) (interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
if l == 0 {
return nil, nil
}
var start, end int
if len(indices) > 0 {
start = toInt(indices[0])
}
if len(indices) < 2 {
end = l
} else {
end = toInt(indices[1])
}
return l2.Slice(start, end).Interface(), nil
default:
return nil, fmt.Errorf("list should be type of slice or array but %s", tp)
}
}
func concat(lists ...interface{}) interface{} {
var res []interface{}
for _, list := range lists {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
for i := 0; i < l2.Len(); i++ {
res = append(res, l2.Index(i).Interface())
}
default:
panic(fmt.Sprintf("Cannot concat type %s as list", tp))
}
}
return res
}
-12
View File
@@ -1,12 +0,0 @@
package sprig
import (
"math/rand"
"net"
)
func getHostByName(name string) string {
addrs, _ := net.LookupHost(name)
//TODO: add error handing when release v3 comes out
return addrs[rand.Intn(len(addrs))]
}
-228
View File
@@ -1,228 +0,0 @@
package sprig
import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
)
// toFloat64 converts 64-bit floats
func toFloat64(v interface{}) float64 {
if str, ok := v.(string); ok {
iv, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return iv
}
val := reflect.Indirect(reflect.ValueOf(v))
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return float64(val.Int())
case reflect.Uint8, reflect.Uint16, reflect.Uint32:
return float64(val.Uint())
case reflect.Uint, reflect.Uint64:
return float64(val.Uint())
case reflect.Float32, reflect.Float64:
return val.Float()
case reflect.Bool:
if val.Bool() {
return 1
}
return 0
default:
return 0
}
}
func toInt(v interface{}) int {
//It's not optimal. Bud I don't want duplicate toInt64 code.
return int(toInt64(v))
}
// toInt64 converts integer types to 64-bit integers
func toInt64(v interface{}) int64 {
if str, ok := v.(string); ok {
iv, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0
}
return iv
}
val := reflect.Indirect(reflect.ValueOf(v))
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return val.Int()
case reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(val.Uint())
case reflect.Uint, reflect.Uint64:
tv := val.Uint()
if tv <= math.MaxInt64 {
return int64(tv)
}
// TODO: What is the sensible thing to do here?
return math.MaxInt64
case reflect.Float32, reflect.Float64:
return int64(val.Float())
case reflect.Bool:
if val.Bool() {
return 1
}
return 0
default:
return 0
}
}
func max(a interface{}, i ...interface{}) int64 {
aa := toInt64(a)
for _, b := range i {
bb := toInt64(b)
if bb > aa {
aa = bb
}
}
return aa
}
func maxf(a interface{}, i ...interface{}) float64 {
aa := toFloat64(a)
for _, b := range i {
bb := toFloat64(b)
aa = math.Max(aa, bb)
}
return aa
}
func min(a interface{}, i ...interface{}) int64 {
aa := toInt64(a)
for _, b := range i {
bb := toInt64(b)
if bb < aa {
aa = bb
}
}
return aa
}
func minf(a interface{}, i ...interface{}) float64 {
aa := toFloat64(a)
for _, b := range i {
bb := toFloat64(b)
aa = math.Min(aa, bb)
}
return aa
}
func until(count int) []int {
step := 1
if count < 0 {
step = -1
}
return untilStep(0, count, step)
}
func untilStep(start, stop, step int) []int {
v := []int{}
if stop < start {
if step >= 0 {
return v
}
for i := start; i > stop; i += step {
v = append(v, i)
}
return v
}
if step <= 0 {
return v
}
for i := start; i < stop; i += step {
v = append(v, i)
}
return v
}
func floor(a interface{}) float64 {
aa := toFloat64(a)
return math.Floor(aa)
}
func ceil(a interface{}) float64 {
aa := toFloat64(a)
return math.Ceil(aa)
}
func round(a interface{}, p int, rOpt ...float64) float64 {
roundOn := .5
if len(rOpt) > 0 {
roundOn = rOpt[0]
}
val := toFloat64(a)
places := toFloat64(p)
var round float64
pow := math.Pow(10, places)
digit := pow * val
_, div := math.Modf(digit)
if div >= roundOn {
round = math.Ceil(digit)
} else {
round = math.Floor(digit)
}
return round / pow
}
// converts unix octal to decimal
func toDecimal(v interface{}) int64 {
result, err := strconv.ParseInt(fmt.Sprint(v), 8, 64)
if err != nil {
return 0
}
return result
}
func seq(params ...int) string {
increment := 1
switch len(params) {
case 0:
return ""
case 1:
start := 1
end := params[0]
if end < start {
increment = -1
}
return intArrayToString(untilStep(start, end+increment, increment), " ")
case 3:
start := params[0]
end := params[2]
step := params[1]
if end < start {
increment = -1
if step > 0 {
return ""
}
}
return intArrayToString(untilStep(start, end+increment, step), " ")
case 2:
start := params[0]
end := params[1]
step := 1
if end < start {
step = -1
}
return intArrayToString(untilStep(start, end+step, step), " ")
default:
return ""
}
}
func intArrayToString(slice []int, delimeter string) string {
return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(slice)), delimeter), "[]")
}
-28
View File
@@ -1,28 +0,0 @@
package sprig
import (
"fmt"
"reflect"
)
// typeIs returns true if the src is the type named in target.
func typeIs(target string, src interface{}) bool {
return target == typeOf(src)
}
func typeIsLike(target string, src interface{}) bool {
t := typeOf(src)
return target == t || "*"+target == t
}
func typeOf(src interface{}) string {
return fmt.Sprintf("%T", src)
}
func kindIs(target string, src interface{}) bool {
return target == kindOf(src)
}
func kindOf(src interface{}) string {
return reflect.ValueOf(src).Kind().String()
}
-83
View File
@@ -1,83 +0,0 @@
package sprig
import (
"regexp"
)
func regexMatch(regex string, s string) bool {
match, _ := regexp.MatchString(regex, s)
return match
}
func mustRegexMatch(regex string, s string) (bool, error) {
return regexp.MatchString(regex, s)
}
func regexFindAll(regex string, s string, n int) []string {
r := regexp.MustCompile(regex)
return r.FindAllString(s, n)
}
func mustRegexFindAll(regex string, s string, n int) ([]string, error) {
r, err := regexp.Compile(regex)
if err != nil {
return []string{}, err
}
return r.FindAllString(s, n), nil
}
func regexFind(regex string, s string) string {
r := regexp.MustCompile(regex)
return r.FindString(s)
}
func mustRegexFind(regex string, s string) (string, error) {
r, err := regexp.Compile(regex)
if err != nil {
return "", err
}
return r.FindString(s), nil
}
func regexReplaceAll(regex string, s string, repl string) string {
r := regexp.MustCompile(regex)
return r.ReplaceAllString(s, repl)
}
func mustRegexReplaceAll(regex string, s string, repl string) (string, error) {
r, err := regexp.Compile(regex)
if err != nil {
return "", err
}
return r.ReplaceAllString(s, repl), nil
}
func regexReplaceAllLiteral(regex string, s string, repl string) string {
r := regexp.MustCompile(regex)
return r.ReplaceAllLiteralString(s, repl)
}
func mustRegexReplaceAllLiteral(regex string, s string, repl string) (string, error) {
r, err := regexp.Compile(regex)
if err != nil {
return "", err
}
return r.ReplaceAllLiteralString(s, repl), nil
}
func regexSplit(regex string, s string, n int) []string {
r := regexp.MustCompile(regex)
return r.Split(s, n)
}
func mustRegexSplit(regex string, s string, n int) ([]string, error) {
r, err := regexp.Compile(regex)
if err != nil {
return []string{}, err
}
return r.Split(s, n), nil
}
func regexQuoteMeta(s string) string {
return regexp.QuoteMeta(s)
}
-189
View File
@@ -1,189 +0,0 @@
package sprig
import (
"encoding/base32"
"encoding/base64"
"fmt"
"reflect"
"strconv"
"strings"
)
func base64encode(v string) string {
return base64.StdEncoding.EncodeToString([]byte(v))
}
func base64decode(v string) string {
data, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return err.Error()
}
return string(data)
}
func base32encode(v string) string {
return base32.StdEncoding.EncodeToString([]byte(v))
}
func base32decode(v string) string {
data, err := base32.StdEncoding.DecodeString(v)
if err != nil {
return err.Error()
}
return string(data)
}
func quote(str ...interface{}) string {
out := make([]string, 0, len(str))
for _, s := range str {
if s != nil {
out = append(out, fmt.Sprintf("%q", strval(s)))
}
}
return strings.Join(out, " ")
}
func squote(str ...interface{}) string {
out := make([]string, 0, len(str))
for _, s := range str {
if s != nil {
out = append(out, fmt.Sprintf("'%v'", s))
}
}
return strings.Join(out, " ")
}
func cat(v ...interface{}) string {
v = removeNilElements(v)
r := strings.TrimSpace(strings.Repeat("%v ", len(v)))
return fmt.Sprintf(r, v...)
}
func indent(spaces int, v string) string {
pad := strings.Repeat(" ", spaces)
return pad + strings.Replace(v, "\n", "\n"+pad, -1)
}
func nindent(spaces int, v string) string {
return "\n" + indent(spaces, v)
}
func replace(old, new, src string) string {
return strings.Replace(src, old, new, -1)
}
func plural(one, many string, count int) string {
if count == 1 {
return one
}
return many
}
func strslice(v interface{}) []string {
switch v := v.(type) {
case []string:
return v
case []interface{}:
b := make([]string, 0, len(v))
for _, s := range v {
if s != nil {
b = append(b, strval(s))
}
}
return b
default:
val := reflect.ValueOf(v)
switch val.Kind() {
case reflect.Array, reflect.Slice:
l := val.Len()
b := make([]string, 0, l)
for i := 0; i < l; i++ {
value := val.Index(i).Interface()
if value != nil {
b = append(b, strval(value))
}
}
return b
default:
if v == nil {
return []string{}
}
return []string{strval(v)}
}
}
}
func removeNilElements(v []interface{}) []interface{} {
newSlice := make([]interface{}, 0, len(v))
for _, i := range v {
if i != nil {
newSlice = append(newSlice, i)
}
}
return newSlice
}
func strval(v interface{}) string {
switch v := v.(type) {
case string:
return v
case []byte:
return string(v)
case error:
return v.Error()
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
}
func trunc(c int, s string) string {
if c < 0 && len(s)+c > 0 {
return s[len(s)+c:]
}
if c >= 0 && len(s) > c {
return s[:c]
}
return s
}
func join(sep string, v interface{}) string {
return strings.Join(strslice(v), sep)
}
func split(sep, orig string) map[string]string {
parts := strings.Split(orig, sep)
res := make(map[string]string, len(parts))
for i, v := range parts {
res["_"+strconv.Itoa(i)] = v
}
return res
}
func splitn(sep string, n int, orig string) map[string]string {
parts := strings.SplitN(orig, sep, n)
res := make(map[string]string, len(parts))
for i, v := range parts {
res["_"+strconv.Itoa(i)] = v
}
return res
}
// substring creates a substring of the given string.
//
// If start is < 0, this calls string[:end].
//
// If start is >= 0 and end < 0 or end bigger than s length, this calls string[start:]
//
// Otherwise, this calls string[start, end].
func substring(start, end int, s string) string {
if start < 0 {
return s[:end]
}
if end < 0 || end > len(s) {
return s[start:]
}
return s[start:end]
}
-66
View File
@@ -1,66 +0,0 @@
package sprig
import (
"fmt"
"net/url"
"reflect"
)
func dictGetOrEmpty(dict map[string]interface{}, key string) string {
value, ok := dict[key]
if !ok {
return ""
}
tp := reflect.TypeOf(value).Kind()
if tp != reflect.String {
panic(fmt.Sprintf("unable to parse %s key, must be of type string, but %s found", key, tp.String()))
}
return reflect.ValueOf(value).String()
}
// parses given URL to return dict object
func urlParse(v string) map[string]interface{} {
dict := map[string]interface{}{}
parsedURL, err := url.Parse(v)
if err != nil {
panic(fmt.Sprintf("unable to parse url: %s", err))
}
dict["scheme"] = parsedURL.Scheme
dict["host"] = parsedURL.Host
dict["hostname"] = parsedURL.Hostname()
dict["path"] = parsedURL.Path
dict["query"] = parsedURL.RawQuery
dict["opaque"] = parsedURL.Opaque
dict["fragment"] = parsedURL.Fragment
if parsedURL.User != nil {
dict["userinfo"] = parsedURL.User.String()
} else {
dict["userinfo"] = ""
}
return dict
}
// join given dict to URL string
func urlJoin(d map[string]interface{}) string {
resURL := url.URL{
Scheme: dictGetOrEmpty(d, "scheme"),
Host: dictGetOrEmpty(d, "host"),
Path: dictGetOrEmpty(d, "path"),
RawQuery: dictGetOrEmpty(d, "query"),
Opaque: dictGetOrEmpty(d, "opaque"),
Fragment: dictGetOrEmpty(d, "fragment"),
}
userinfo := dictGetOrEmpty(d, "userinfo")
var user *url.Userinfo
if userinfo != "" {
tempURL, err := url.Parse(fmt.Sprintf("proto://%s@host", userinfo))
if err != nil {
panic(fmt.Sprintf("unable to parse userinfo in dict: %s", err))
}
user = tempURL.User
}
resURL.User = user
return resURL.String()
}
-7
View File
@@ -1,7 +0,0 @@
# This is the official list of pprof authors for copyright purposes.
# This file is distinct from the CONTRIBUTORS files.
# See the latter for an explanation.
# Names should be added to this file as:
# Name or Organization <email address>
# The email address is not required for organizations.
Google Inc.
-16
View File
@@ -1,16 +0,0 @@
# People who have agreed to one of the CLAs and can contribute patches.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# https://developers.google.com/open-source/cla/individual
# https://developers.google.com/open-source/cla/corporate
#
# Names should be added to this file as:
# Name <email address>
Raul Silvera <rsilvera@google.com>
Tipp Moseley <tipp@google.com>
Hyoun Kyu Cho <netforce@google.com>
Martin Spier <spiermar@gmail.com>
Taco de Wolff <tacodewolff@gmail.com>
Andrew Hunter <andrewhhunter@gmail.com>
-202
View File
@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-596
View File
@@ -1,596 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package profile
import (
"errors"
"sort"
"strings"
)
func (p *Profile) decoder() []decoder {
return profileDecoder
}
// preEncode populates the unexported fields to be used by encode
// (with suffix X) from the corresponding exported fields. The
// exported fields are cleared up to facilitate testing.
func (p *Profile) preEncode() {
strings := make(map[string]int)
addString(strings, "")
for _, st := range p.SampleType {
st.typeX = addString(strings, st.Type)
st.unitX = addString(strings, st.Unit)
}
for _, s := range p.Sample {
s.labelX = nil
var keys []string
for k := range s.Label {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := s.Label[k]
for _, v := range vs {
s.labelX = append(s.labelX,
label{
keyX: addString(strings, k),
strX: addString(strings, v),
},
)
}
}
var numKeys []string
for k := range s.NumLabel {
numKeys = append(numKeys, k)
}
sort.Strings(numKeys)
for _, k := range numKeys {
keyX := addString(strings, k)
vs := s.NumLabel[k]
units := s.NumUnit[k]
for i, v := range vs {
var unitX int64
if len(units) != 0 {
unitX = addString(strings, units[i])
}
s.labelX = append(s.labelX,
label{
keyX: keyX,
numX: v,
unitX: unitX,
},
)
}
}
s.locationIDX = make([]uint64, len(s.Location))
for i, loc := range s.Location {
s.locationIDX[i] = loc.ID
}
}
for _, m := range p.Mapping {
m.fileX = addString(strings, m.File)
m.buildIDX = addString(strings, m.BuildID)
}
for _, l := range p.Location {
for i, ln := range l.Line {
if ln.Function != nil {
l.Line[i].functionIDX = ln.Function.ID
} else {
l.Line[i].functionIDX = 0
}
}
if l.Mapping != nil {
l.mappingIDX = l.Mapping.ID
} else {
l.mappingIDX = 0
}
}
for _, f := range p.Function {
f.nameX = addString(strings, f.Name)
f.systemNameX = addString(strings, f.SystemName)
f.filenameX = addString(strings, f.Filename)
}
p.dropFramesX = addString(strings, p.DropFrames)
p.keepFramesX = addString(strings, p.KeepFrames)
if pt := p.PeriodType; pt != nil {
pt.typeX = addString(strings, pt.Type)
pt.unitX = addString(strings, pt.Unit)
}
p.commentX = nil
for _, c := range p.Comments {
p.commentX = append(p.commentX, addString(strings, c))
}
p.defaultSampleTypeX = addString(strings, p.DefaultSampleType)
p.docURLX = addString(strings, p.DocURL)
p.stringTable = make([]string, len(strings))
for s, i := range strings {
p.stringTable[i] = s
}
}
func (p *Profile) encode(b *buffer) {
for _, x := range p.SampleType {
encodeMessage(b, 1, x)
}
for _, x := range p.Sample {
encodeMessage(b, 2, x)
}
for _, x := range p.Mapping {
encodeMessage(b, 3, x)
}
for _, x := range p.Location {
encodeMessage(b, 4, x)
}
for _, x := range p.Function {
encodeMessage(b, 5, x)
}
encodeStrings(b, 6, p.stringTable)
encodeInt64Opt(b, 7, p.dropFramesX)
encodeInt64Opt(b, 8, p.keepFramesX)
encodeInt64Opt(b, 9, p.TimeNanos)
encodeInt64Opt(b, 10, p.DurationNanos)
if pt := p.PeriodType; pt != nil && (pt.typeX != 0 || pt.unitX != 0) {
encodeMessage(b, 11, p.PeriodType)
}
encodeInt64Opt(b, 12, p.Period)
encodeInt64s(b, 13, p.commentX)
encodeInt64(b, 14, p.defaultSampleTypeX)
encodeInt64Opt(b, 15, p.docURLX)
}
var profileDecoder = []decoder{
nil, // 0
// repeated ValueType sample_type = 1
func(b *buffer, m message) error {
x := new(ValueType)
pp := m.(*Profile)
pp.SampleType = append(pp.SampleType, x)
return decodeMessage(b, x)
},
// repeated Sample sample = 2
func(b *buffer, m message) error {
x := new(Sample)
pp := m.(*Profile)
pp.Sample = append(pp.Sample, x)
return decodeMessage(b, x)
},
// repeated Mapping mapping = 3
func(b *buffer, m message) error {
x := new(Mapping)
pp := m.(*Profile)
pp.Mapping = append(pp.Mapping, x)
return decodeMessage(b, x)
},
// repeated Location location = 4
func(b *buffer, m message) error {
x := new(Location)
x.Line = b.tmpLines[:0] // Use shared space temporarily
pp := m.(*Profile)
pp.Location = append(pp.Location, x)
err := decodeMessage(b, x)
b.tmpLines = x.Line[:0]
// Copy to shrink size and detach from shared space.
x.Line = append([]Line(nil), x.Line...)
return err
},
// repeated Function function = 5
func(b *buffer, m message) error {
x := new(Function)
pp := m.(*Profile)
pp.Function = append(pp.Function, x)
return decodeMessage(b, x)
},
// repeated string string_table = 6
func(b *buffer, m message) error {
err := decodeStrings(b, &m.(*Profile).stringTable)
if err != nil {
return err
}
if m.(*Profile).stringTable[0] != "" {
return errors.New("string_table[0] must be ''")
}
return nil
},
// int64 drop_frames = 7
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).dropFramesX) },
// int64 keep_frames = 8
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).keepFramesX) },
// int64 time_nanos = 9
func(b *buffer, m message) error {
if m.(*Profile).TimeNanos != 0 {
return errConcatProfile
}
return decodeInt64(b, &m.(*Profile).TimeNanos)
},
// int64 duration_nanos = 10
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).DurationNanos) },
// ValueType period_type = 11
func(b *buffer, m message) error {
x := new(ValueType)
pp := m.(*Profile)
pp.PeriodType = x
return decodeMessage(b, x)
},
// int64 period = 12
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).Period) },
// repeated int64 comment = 13
func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Profile).commentX) },
// int64 defaultSampleType = 14
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).defaultSampleTypeX) },
// string doc_link = 15;
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).docURLX) },
}
// postDecode takes the unexported fields populated by decode (with
// suffix X) and populates the corresponding exported fields.
// The unexported fields are cleared up to facilitate testing.
func (p *Profile) postDecode() error {
var err error
mappings := make(map[uint64]*Mapping, len(p.Mapping))
mappingIds := make([]*Mapping, len(p.Mapping)+1)
for _, m := range p.Mapping {
m.File, err = getString(p.stringTable, &m.fileX, err)
m.BuildID, err = getString(p.stringTable, &m.buildIDX, err)
if m.ID < uint64(len(mappingIds)) {
mappingIds[m.ID] = m
} else {
mappings[m.ID] = m
}
// If this a main linux kernel mapping with a relocation symbol suffix
// ("[kernel.kallsyms]_text"), extract said suffix.
// It is fairly hacky to handle at this level, but the alternatives appear even worse.
const prefix = "[kernel.kallsyms]"
if strings.HasPrefix(m.File, prefix) {
m.KernelRelocationSymbol = m.File[len(prefix):]
}
}
functions := make(map[uint64]*Function, len(p.Function))
functionIds := make([]*Function, len(p.Function)+1)
for _, f := range p.Function {
f.Name, err = getString(p.stringTable, &f.nameX, err)
f.SystemName, err = getString(p.stringTable, &f.systemNameX, err)
f.Filename, err = getString(p.stringTable, &f.filenameX, err)
if f.ID < uint64(len(functionIds)) {
functionIds[f.ID] = f
} else {
functions[f.ID] = f
}
}
locations := make(map[uint64]*Location, len(p.Location))
locationIds := make([]*Location, len(p.Location)+1)
for _, l := range p.Location {
if id := l.mappingIDX; id < uint64(len(mappingIds)) {
l.Mapping = mappingIds[id]
} else {
l.Mapping = mappings[id]
}
l.mappingIDX = 0
for i, ln := range l.Line {
if id := ln.functionIDX; id != 0 {
l.Line[i].functionIDX = 0
if id < uint64(len(functionIds)) {
l.Line[i].Function = functionIds[id]
} else {
l.Line[i].Function = functions[id]
}
}
}
if l.ID < uint64(len(locationIds)) {
locationIds[l.ID] = l
} else {
locations[l.ID] = l
}
}
for _, st := range p.SampleType {
st.Type, err = getString(p.stringTable, &st.typeX, err)
st.Unit, err = getString(p.stringTable, &st.unitX, err)
}
// Pre-allocate space for all locations.
numLocations := 0
for _, s := range p.Sample {
numLocations += len(s.locationIDX)
}
locBuffer := make([]*Location, numLocations)
for _, s := range p.Sample {
if len(s.labelX) > 0 {
labels := make(map[string][]string, len(s.labelX))
numLabels := make(map[string][]int64, len(s.labelX))
numUnits := make(map[string][]string, len(s.labelX))
for _, l := range s.labelX {
var key, value string
key, err = getString(p.stringTable, &l.keyX, err)
if l.strX != 0 {
value, err = getString(p.stringTable, &l.strX, err)
labels[key] = append(labels[key], value)
} else if l.numX != 0 || l.unitX != 0 {
numValues := numLabels[key]
units := numUnits[key]
if l.unitX != 0 {
var unit string
unit, err = getString(p.stringTable, &l.unitX, err)
units = padStringArray(units, len(numValues))
numUnits[key] = append(units, unit)
}
numLabels[key] = append(numLabels[key], l.numX)
}
}
if len(labels) > 0 {
s.Label = labels
}
if len(numLabels) > 0 {
s.NumLabel = numLabels
for key, units := range numUnits {
if len(units) > 0 {
numUnits[key] = padStringArray(units, len(numLabels[key]))
}
}
s.NumUnit = numUnits
}
}
s.Location = locBuffer[:len(s.locationIDX)]
locBuffer = locBuffer[len(s.locationIDX):]
for i, lid := range s.locationIDX {
if lid < uint64(len(locationIds)) {
s.Location[i] = locationIds[lid]
} else {
s.Location[i] = locations[lid]
}
}
s.locationIDX = nil
}
p.DropFrames, err = getString(p.stringTable, &p.dropFramesX, err)
p.KeepFrames, err = getString(p.stringTable, &p.keepFramesX, err)
if pt := p.PeriodType; pt == nil {
p.PeriodType = &ValueType{}
}
if pt := p.PeriodType; pt != nil {
pt.Type, err = getString(p.stringTable, &pt.typeX, err)
pt.Unit, err = getString(p.stringTable, &pt.unitX, err)
}
for _, i := range p.commentX {
var c string
c, err = getString(p.stringTable, &i, err)
p.Comments = append(p.Comments, c)
}
p.commentX = nil
p.DefaultSampleType, err = getString(p.stringTable, &p.defaultSampleTypeX, err)
p.DocURL, err = getString(p.stringTable, &p.docURLX, err)
p.stringTable = nil
return err
}
// padStringArray pads arr with enough empty strings to make arr
// length l when arr's length is less than l.
func padStringArray(arr []string, l int) []string {
if l <= len(arr) {
return arr
}
return append(arr, make([]string, l-len(arr))...)
}
func (p *ValueType) decoder() []decoder {
return valueTypeDecoder
}
func (p *ValueType) encode(b *buffer) {
encodeInt64Opt(b, 1, p.typeX)
encodeInt64Opt(b, 2, p.unitX)
}
var valueTypeDecoder = []decoder{
nil, // 0
// optional int64 type = 1
func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).typeX) },
// optional int64 unit = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).unitX) },
}
func (p *Sample) decoder() []decoder {
return sampleDecoder
}
func (p *Sample) encode(b *buffer) {
encodeUint64s(b, 1, p.locationIDX)
encodeInt64s(b, 2, p.Value)
for _, x := range p.labelX {
encodeMessage(b, 3, x)
}
}
var sampleDecoder = []decoder{
nil, // 0
// repeated uint64 location = 1
func(b *buffer, m message) error { return decodeUint64s(b, &m.(*Sample).locationIDX) },
// repeated int64 value = 2
func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Sample).Value) },
// repeated Label label = 3
func(b *buffer, m message) error {
s := m.(*Sample)
n := len(s.labelX)
s.labelX = append(s.labelX, label{})
return decodeMessage(b, &s.labelX[n])
},
}
func (p label) decoder() []decoder {
return labelDecoder
}
func (p label) encode(b *buffer) {
encodeInt64Opt(b, 1, p.keyX)
encodeInt64Opt(b, 2, p.strX)
encodeInt64Opt(b, 3, p.numX)
encodeInt64Opt(b, 4, p.unitX)
}
var labelDecoder = []decoder{
nil, // 0
// optional int64 key = 1
func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).keyX) },
// optional int64 str = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).strX) },
// optional int64 num = 3
func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).numX) },
// optional int64 num = 4
func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).unitX) },
}
func (p *Mapping) decoder() []decoder {
return mappingDecoder
}
func (p *Mapping) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeUint64Opt(b, 2, p.Start)
encodeUint64Opt(b, 3, p.Limit)
encodeUint64Opt(b, 4, p.Offset)
encodeInt64Opt(b, 5, p.fileX)
encodeInt64Opt(b, 6, p.buildIDX)
encodeBoolOpt(b, 7, p.HasFunctions)
encodeBoolOpt(b, 8, p.HasFilenames)
encodeBoolOpt(b, 9, p.HasLineNumbers)
encodeBoolOpt(b, 10, p.HasInlineFrames)
}
var mappingDecoder = []decoder{
nil, // 0
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).ID) }, // optional uint64 id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Start) }, // optional uint64 memory_offset = 2
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Limit) }, // optional uint64 memory_limit = 3
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Offset) }, // optional uint64 file_offset = 4
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).fileX) }, // optional int64 filename = 5
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).buildIDX) }, // optional int64 build_id = 6
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasFunctions) }, // optional bool has_functions = 7
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasFilenames) }, // optional bool has_filenames = 8
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasLineNumbers) }, // optional bool has_line_numbers = 9
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasInlineFrames) }, // optional bool has_inline_frames = 10
}
func (p *Location) decoder() []decoder {
return locationDecoder
}
func (p *Location) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeUint64Opt(b, 2, p.mappingIDX)
encodeUint64Opt(b, 3, p.Address)
for i := range p.Line {
encodeMessage(b, 4, &p.Line[i])
}
encodeBoolOpt(b, 5, p.IsFolded)
}
var locationDecoder = []decoder{
nil, // 0
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).ID) }, // optional uint64 id = 1;
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).mappingIDX) }, // optional uint64 mapping_id = 2;
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).Address) }, // optional uint64 address = 3;
func(b *buffer, m message) error { // repeated Line line = 4
pp := m.(*Location)
n := len(pp.Line)
pp.Line = append(pp.Line, Line{})
return decodeMessage(b, &pp.Line[n])
},
func(b *buffer, m message) error { return decodeBool(b, &m.(*Location).IsFolded) }, // optional bool is_folded = 5;
}
func (p *Line) decoder() []decoder {
return lineDecoder
}
func (p *Line) encode(b *buffer) {
encodeUint64Opt(b, 1, p.functionIDX)
encodeInt64Opt(b, 2, p.Line)
encodeInt64Opt(b, 3, p.Column)
}
var lineDecoder = []decoder{
nil, // 0
// optional uint64 function_id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Line).functionIDX) },
// optional int64 line = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Line).Line) },
// optional int64 column = 3
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Line).Column) },
}
func (p *Function) decoder() []decoder {
return functionDecoder
}
func (p *Function) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeInt64Opt(b, 2, p.nameX)
encodeInt64Opt(b, 3, p.systemNameX)
encodeInt64Opt(b, 4, p.filenameX)
encodeInt64Opt(b, 5, p.StartLine)
}
var functionDecoder = []decoder{
nil, // 0
// optional uint64 id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Function).ID) },
// optional int64 function_name = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).nameX) },
// optional int64 function_system_name = 3
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).systemNameX) },
// repeated int64 filename = 4
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).filenameX) },
// optional int64 start_line = 5
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).StartLine) },
}
func addString(strings map[string]int, s string) int64 {
i, ok := strings[s]
if !ok {
i = len(strings)
strings[s] = i
}
return int64(i)
}
func getString(strings []string, strng *int64, err error) (string, error) {
if err != nil {
return "", err
}
s := int(*strng)
if s < 0 || s >= len(strings) {
return "", errMalformed
}
*strng = 0
return strings[s], nil
}
-274
View File
@@ -1,274 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package profile
// Implements methods to filter samples from profiles.
import "regexp"
// FilterSamplesByName filters the samples in a profile and only keeps
// samples where at least one frame matches focus but none match ignore.
// Returns true is the corresponding regexp matched at least one sample.
func (p *Profile) FilterSamplesByName(focus, ignore, hide, show *regexp.Regexp) (fm, im, hm, hnm bool) {
if focus == nil && ignore == nil && hide == nil && show == nil {
fm = true // Missing focus implies a match
return
}
focusOrIgnore := make(map[uint64]bool)
hidden := make(map[uint64]bool)
for _, l := range p.Location {
if ignore != nil && l.matchesName(ignore) {
im = true
focusOrIgnore[l.ID] = false
} else if focus == nil || l.matchesName(focus) {
fm = true
focusOrIgnore[l.ID] = true
}
if hide != nil && l.matchesName(hide) {
hm = true
l.Line = l.unmatchedLines(hide)
if len(l.Line) == 0 {
hidden[l.ID] = true
}
}
if show != nil {
l.Line = l.matchedLines(show)
if len(l.Line) == 0 {
hidden[l.ID] = true
} else {
hnm = true
}
}
}
s := make([]*Sample, 0, len(p.Sample))
for _, sample := range p.Sample {
if focusedAndNotIgnored(sample.Location, focusOrIgnore) {
if len(hidden) > 0 {
var locs []*Location
for _, loc := range sample.Location {
if !hidden[loc.ID] {
locs = append(locs, loc)
}
}
if len(locs) == 0 {
// Remove sample with no locations (by not adding it to s).
continue
}
sample.Location = locs
}
s = append(s, sample)
}
}
p.Sample = s
return
}
// ShowFrom drops all stack frames above the highest matching frame and returns
// whether a match was found. If showFrom is nil it returns false and does not
// modify the profile.
//
// Example: consider a sample with frames [A, B, C, B], where A is the root.
// ShowFrom(nil) returns false and has frames [A, B, C, B].
// ShowFrom(A) returns true and has frames [A, B, C, B].
// ShowFrom(B) returns true and has frames [B, C, B].
// ShowFrom(C) returns true and has frames [C, B].
// ShowFrom(D) returns false and drops the sample because no frames remain.
func (p *Profile) ShowFrom(showFrom *regexp.Regexp) (matched bool) {
if showFrom == nil {
return false
}
// showFromLocs stores location IDs that matched ShowFrom.
showFromLocs := make(map[uint64]bool)
// Apply to locations.
for _, loc := range p.Location {
if filterShowFromLocation(loc, showFrom) {
showFromLocs[loc.ID] = true
matched = true
}
}
// For all samples, strip locations after the highest matching one.
s := make([]*Sample, 0, len(p.Sample))
for _, sample := range p.Sample {
for i := len(sample.Location) - 1; i >= 0; i-- {
if showFromLocs[sample.Location[i].ID] {
sample.Location = sample.Location[:i+1]
s = append(s, sample)
break
}
}
}
p.Sample = s
return matched
}
// filterShowFromLocation tests a showFrom regex against a location, removes
// lines after the last match and returns whether a match was found. If the
// mapping is matched, then all lines are kept.
func filterShowFromLocation(loc *Location, showFrom *regexp.Regexp) bool {
if m := loc.Mapping; m != nil && showFrom.MatchString(m.File) {
return true
}
if i := loc.lastMatchedLineIndex(showFrom); i >= 0 {
loc.Line = loc.Line[:i+1]
return true
}
return false
}
// lastMatchedLineIndex returns the index of the last line that matches a regex,
// or -1 if no match is found.
func (loc *Location) lastMatchedLineIndex(re *regexp.Regexp) int {
for i := len(loc.Line) - 1; i >= 0; i-- {
if fn := loc.Line[i].Function; fn != nil {
if re.MatchString(fn.Name) || re.MatchString(fn.Filename) {
return i
}
}
}
return -1
}
// FilterTagsByName filters the tags in a profile and only keeps
// tags that match show and not hide.
func (p *Profile) FilterTagsByName(show, hide *regexp.Regexp) (sm, hm bool) {
matchRemove := func(name string) bool {
matchShow := show == nil || show.MatchString(name)
matchHide := hide != nil && hide.MatchString(name)
if matchShow {
sm = true
}
if matchHide {
hm = true
}
return !matchShow || matchHide
}
for _, s := range p.Sample {
for lab := range s.Label {
if matchRemove(lab) {
delete(s.Label, lab)
}
}
for lab := range s.NumLabel {
if matchRemove(lab) {
delete(s.NumLabel, lab)
}
}
}
return
}
// matchesName returns whether the location matches the regular
// expression. It checks any available function names, file names, and
// mapping object filename.
func (loc *Location) matchesName(re *regexp.Regexp) bool {
for _, ln := range loc.Line {
if fn := ln.Function; fn != nil {
if re.MatchString(fn.Name) || re.MatchString(fn.Filename) {
return true
}
}
}
if m := loc.Mapping; m != nil && re.MatchString(m.File) {
return true
}
return false
}
// unmatchedLines returns the lines in the location that do not match
// the regular expression.
func (loc *Location) unmatchedLines(re *regexp.Regexp) []Line {
if m := loc.Mapping; m != nil && re.MatchString(m.File) {
return nil
}
var lines []Line
for _, ln := range loc.Line {
if fn := ln.Function; fn != nil {
if re.MatchString(fn.Name) || re.MatchString(fn.Filename) {
continue
}
}
lines = append(lines, ln)
}
return lines
}
// matchedLines returns the lines in the location that match
// the regular expression.
func (loc *Location) matchedLines(re *regexp.Regexp) []Line {
if m := loc.Mapping; m != nil && re.MatchString(m.File) {
return loc.Line
}
var lines []Line
for _, ln := range loc.Line {
if fn := ln.Function; fn != nil {
if !re.MatchString(fn.Name) && !re.MatchString(fn.Filename) {
continue
}
}
lines = append(lines, ln)
}
return lines
}
// focusedAndNotIgnored looks up a slice of ids against a map of
// focused/ignored locations. The map only contains locations that are
// explicitly focused or ignored. Returns whether there is at least
// one focused location but no ignored locations.
func focusedAndNotIgnored(locs []*Location, m map[uint64]bool) bool {
var f bool
for _, loc := range locs {
if focus, focusOrIgnore := m[loc.ID]; focusOrIgnore {
if focus {
// Found focused location. Must keep searching in case there
// is an ignored one as well.
f = true
} else {
// Found ignored location. Can return false right away.
return false
}
}
}
return f
}
// TagMatch selects tags for filtering
type TagMatch func(s *Sample) bool
// FilterSamplesByTag removes all samples from the profile, except
// those that match focus and do not match the ignore regular
// expression.
func (p *Profile) FilterSamplesByTag(focus, ignore TagMatch) (fm, im bool) {
samples := make([]*Sample, 0, len(p.Sample))
for _, s := range p.Sample {
focused, ignored := true, false
if focus != nil {
focused = focus(s)
}
if ignore != nil {
ignored = ignore(s)
}
fm = fm || focused
im = im || ignored
if focused && !ignored {
samples = append(samples, s)
}
}
p.Sample = samples
return
}
-64
View File
@@ -1,64 +0,0 @@
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package profile
import (
"fmt"
"strconv"
"strings"
)
// SampleIndexByName returns the appropriate index for a value of sample index.
// If numeric, it returns the number, otherwise it looks up the text in the
// profile sample types.
func (p *Profile) SampleIndexByName(sampleIndex string) (int, error) {
if sampleIndex == "" {
if dst := p.DefaultSampleType; dst != "" {
for i, t := range sampleTypes(p) {
if t == dst {
return i, nil
}
}
}
// By default select the last sample value
return len(p.SampleType) - 1, nil
}
if i, err := strconv.Atoi(sampleIndex); err == nil {
if i < 0 || i >= len(p.SampleType) {
return 0, fmt.Errorf("sample_index %s is outside the range [0..%d]", sampleIndex, len(p.SampleType)-1)
}
return i, nil
}
// Remove the inuse_ prefix to support legacy pprof options
// "inuse_space" and "inuse_objects" for profiles containing types
// "space" and "objects".
noInuse := strings.TrimPrefix(sampleIndex, "inuse_")
for i, t := range p.SampleType {
if t.Type == sampleIndex || t.Type == noInuse {
return i, nil
}
}
return 0, fmt.Errorf("sample_index %q must be one of: %v", sampleIndex, sampleTypes(p))
}
func sampleTypes(p *Profile) []string {
types := make([]string, len(p.SampleType))
for i, t := range p.SampleType {
types[i] = t.Type
}
return types
}
-315
View File
@@ -1,315 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file implements parsers to convert java legacy profiles into
// the profile.proto format.
package profile
import (
"bytes"
"fmt"
"io"
"path/filepath"
"regexp"
"strconv"
"strings"
)
var (
attributeRx = regexp.MustCompile(`([\w ]+)=([\w ]+)`)
javaSampleRx = regexp.MustCompile(` *(\d+) +(\d+) +@ +([ x0-9a-f]*)`)
javaLocationRx = regexp.MustCompile(`^\s*0x([[:xdigit:]]+)\s+(.*)\s*$`)
javaLocationFileLineRx = regexp.MustCompile(`^(.*)\s+\((.+):(-?[[:digit:]]+)\)$`)
javaLocationPathRx = regexp.MustCompile(`^(.*)\s+\((.*)\)$`)
)
// javaCPUProfile returns a new Profile from profilez data.
// b is the profile bytes after the header, period is the profiling
// period, and parse is a function to parse 8-byte chunks from the
// profile in its native endianness.
func javaCPUProfile(b []byte, period int64, parse func(b []byte) (uint64, []byte)) (*Profile, error) {
p := &Profile{
Period: period * 1000,
PeriodType: &ValueType{Type: "cpu", Unit: "nanoseconds"},
SampleType: []*ValueType{{Type: "samples", Unit: "count"}, {Type: "cpu", Unit: "nanoseconds"}},
}
var err error
var locs map[uint64]*Location
if b, locs, err = parseCPUSamples(b, parse, false, p); err != nil {
return nil, err
}
if err = parseJavaLocations(b, locs, p); err != nil {
return nil, err
}
// Strip out addresses for better merge.
if err = p.Aggregate(true, true, true, true, false, false); err != nil {
return nil, err
}
return p, nil
}
// parseJavaProfile returns a new profile from heapz or contentionz
// data. b is the profile bytes after the header.
func parseJavaProfile(b []byte) (*Profile, error) {
h := bytes.SplitAfterN(b, []byte("\n"), 2)
if len(h) < 2 {
return nil, errUnrecognized
}
p := &Profile{
PeriodType: &ValueType{},
}
header := string(bytes.TrimSpace(h[0]))
var err error
var pType string
switch header {
case "--- heapz 1 ---":
pType = "heap"
case "--- contentionz 1 ---":
pType = "contention"
default:
return nil, errUnrecognized
}
if b, err = parseJavaHeader(pType, h[1], p); err != nil {
return nil, err
}
var locs map[uint64]*Location
if b, locs, err = parseJavaSamples(pType, b, p); err != nil {
return nil, err
}
if err = parseJavaLocations(b, locs, p); err != nil {
return nil, err
}
// Strip out addresses for better merge.
if err = p.Aggregate(true, true, true, true, false, false); err != nil {
return nil, err
}
return p, nil
}
// parseJavaHeader parses the attribute section on a java profile and
// populates a profile. Returns the remainder of the buffer after all
// attributes.
func parseJavaHeader(pType string, b []byte, p *Profile) ([]byte, error) {
nextNewLine := bytes.IndexByte(b, byte('\n'))
for nextNewLine != -1 {
line := string(bytes.TrimSpace(b[0:nextNewLine]))
if line != "" {
h := attributeRx.FindStringSubmatch(line)
if h == nil {
// Not a valid attribute, exit.
return b, nil
}
attribute, value := strings.TrimSpace(h[1]), strings.TrimSpace(h[2])
var err error
switch pType + "/" + attribute {
case "heap/format", "cpu/format", "contention/format":
if value != "java" {
return nil, errUnrecognized
}
case "heap/resolution":
p.SampleType = []*ValueType{
{Type: "inuse_objects", Unit: "count"},
{Type: "inuse_space", Unit: value},
}
case "contention/resolution":
p.SampleType = []*ValueType{
{Type: "contentions", Unit: "count"},
{Type: "delay", Unit: value},
}
case "contention/sampling period":
p.PeriodType = &ValueType{
Type: "contentions", Unit: "count",
}
if p.Period, err = strconv.ParseInt(value, 0, 64); err != nil {
return nil, fmt.Errorf("failed to parse attribute %s: %v", line, err)
}
case "contention/ms since reset":
millis, err := strconv.ParseInt(value, 0, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse attribute %s: %v", line, err)
}
p.DurationNanos = millis * 1000 * 1000
default:
return nil, errUnrecognized
}
}
// Grab next line.
b = b[nextNewLine+1:]
nextNewLine = bytes.IndexByte(b, byte('\n'))
}
return b, nil
}
// parseJavaSamples parses the samples from a java profile and
// populates the Samples in a profile. Returns the remainder of the
// buffer after the samples.
func parseJavaSamples(pType string, b []byte, p *Profile) ([]byte, map[uint64]*Location, error) {
nextNewLine := bytes.IndexByte(b, byte('\n'))
locs := make(map[uint64]*Location)
for nextNewLine != -1 {
line := string(bytes.TrimSpace(b[0:nextNewLine]))
if line != "" {
sample := javaSampleRx.FindStringSubmatch(line)
if sample == nil {
// Not a valid sample, exit.
return b, locs, nil
}
// Java profiles have data/fields inverted compared to other
// profile types.
var err error
value1, value2, value3 := sample[2], sample[1], sample[3]
addrs, err := parseHexAddresses(value3)
if err != nil {
return nil, nil, fmt.Errorf("malformed sample: %s: %v", line, err)
}
var sloc []*Location
for _, addr := range addrs {
loc := locs[addr]
if locs[addr] == nil {
loc = &Location{
Address: addr,
}
p.Location = append(p.Location, loc)
locs[addr] = loc
}
sloc = append(sloc, loc)
}
s := &Sample{
Value: make([]int64, 2),
Location: sloc,
}
if s.Value[0], err = strconv.ParseInt(value1, 0, 64); err != nil {
return nil, nil, fmt.Errorf("parsing sample %s: %v", line, err)
}
if s.Value[1], err = strconv.ParseInt(value2, 0, 64); err != nil {
return nil, nil, fmt.Errorf("parsing sample %s: %v", line, err)
}
switch pType {
case "heap":
const javaHeapzSamplingRate = 524288 // 512K
if s.Value[0] == 0 {
return nil, nil, fmt.Errorf("parsing sample %s: second value must be non-zero", line)
}
s.NumLabel = map[string][]int64{"bytes": {s.Value[1] / s.Value[0]}}
s.Value[0], s.Value[1] = scaleHeapSample(s.Value[0], s.Value[1], javaHeapzSamplingRate)
case "contention":
if period := p.Period; period != 0 {
s.Value[0] = s.Value[0] * p.Period
s.Value[1] = s.Value[1] * p.Period
}
}
p.Sample = append(p.Sample, s)
}
// Grab next line.
b = b[nextNewLine+1:]
nextNewLine = bytes.IndexByte(b, byte('\n'))
}
return b, locs, nil
}
// parseJavaLocations parses the location information in a java
// profile and populates the Locations in a profile. It uses the
// location addresses from the profile as both the ID of each
// location.
func parseJavaLocations(b []byte, locs map[uint64]*Location, p *Profile) error {
r := bytes.NewBuffer(b)
fns := make(map[string]*Function)
for {
line, err := r.ReadString('\n')
if err != nil {
if err != io.EOF {
return err
}
if line == "" {
break
}
}
if line = strings.TrimSpace(line); line == "" {
continue
}
jloc := javaLocationRx.FindStringSubmatch(line)
if len(jloc) != 3 {
continue
}
addr, err := strconv.ParseUint(jloc[1], 16, 64)
if err != nil {
return fmt.Errorf("parsing sample %s: %v", line, err)
}
loc := locs[addr]
if loc == nil {
// Unused/unseen
continue
}
var lineFunc, lineFile string
var lineNo int64
if fileLine := javaLocationFileLineRx.FindStringSubmatch(jloc[2]); len(fileLine) == 4 {
// Found a line of the form: "function (file:line)"
lineFunc, lineFile = fileLine[1], fileLine[2]
if n, err := strconv.ParseInt(fileLine[3], 10, 64); err == nil && n > 0 {
lineNo = n
}
} else if filePath := javaLocationPathRx.FindStringSubmatch(jloc[2]); len(filePath) == 3 {
// If there's not a file:line, it's a shared library path.
// The path isn't interesting, so just give the .so.
lineFunc, lineFile = filePath[1], filepath.Base(filePath[2])
} else if strings.Contains(jloc[2], "generated stub/JIT") {
lineFunc = "STUB"
} else {
// Treat whole line as the function name. This is used by the
// java agent for internal states such as "GC" or "VM".
lineFunc = jloc[2]
}
fn := fns[lineFunc]
if fn == nil {
fn = &Function{
Name: lineFunc,
SystemName: lineFunc,
Filename: lineFile,
}
fns[lineFunc] = fn
p.Function = append(p.Function, fn)
}
loc.Line = []Line{
{
Function: fn,
Line: lineNo,
},
}
loc.Address = 0
}
p.remapLocationIDs()
p.remapFunctionIDs()
p.remapMappingIDs()
return nil
}
File diff suppressed because it is too large Load Diff
-674
View File
@@ -1,674 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package profile
import (
"encoding/binary"
"fmt"
"sort"
"strconv"
"strings"
)
// Compact performs garbage collection on a profile to remove any
// unreferenced fields. This is useful to reduce the size of a profile
// after samples or locations have been removed.
func (p *Profile) Compact() *Profile {
p, _ = Merge([]*Profile{p})
return p
}
// Merge merges all the profiles in profs into a single Profile.
// Returns a new profile independent of the input profiles. The merged
// profile is compacted to eliminate unused samples, locations,
// functions and mappings. Profiles must have identical profile sample
// and period types or the merge will fail. profile.Period of the
// resulting profile will be the maximum of all profiles, and
// profile.TimeNanos will be the earliest nonzero one. Merges are
// associative with the caveat of the first profile having some
// specialization in how headers are combined. There may be other
// subtleties now or in the future regarding associativity.
func Merge(srcs []*Profile) (*Profile, error) {
if len(srcs) == 0 {
return nil, fmt.Errorf("no profiles to merge")
}
p, err := combineHeaders(srcs)
if err != nil {
return nil, err
}
pm := &profileMerger{
p: p,
samples: make(map[sampleKey]*Sample, len(srcs[0].Sample)),
locations: make(map[locationKey]*Location, len(srcs[0].Location)),
functions: make(map[functionKey]*Function, len(srcs[0].Function)),
mappings: make(map[mappingKey]*Mapping, len(srcs[0].Mapping)),
}
for _, src := range srcs {
// Clear the profile-specific hash tables
pm.locationsByID = makeLocationIDMap(len(src.Location))
pm.functionsByID = make(map[uint64]*Function, len(src.Function))
pm.mappingsByID = make(map[uint64]mapInfo, len(src.Mapping))
if len(pm.mappings) == 0 && len(src.Mapping) > 0 {
// The Mapping list has the property that the first mapping
// represents the main binary. Take the first Mapping we see,
// otherwise the operations below will add mappings in an
// arbitrary order.
pm.mapMapping(src.Mapping[0])
}
for _, s := range src.Sample {
if !isZeroSample(s) {
pm.mapSample(s)
}
}
}
for _, s := range p.Sample {
if isZeroSample(s) {
// If there are any zero samples, re-merge the profile to GC
// them.
return Merge([]*Profile{p})
}
}
return p, nil
}
// Normalize normalizes the source profile by multiplying each value in profile by the
// ratio of the sum of the base profile's values of that sample type to the sum of the
// source profile's value of that sample type.
func (p *Profile) Normalize(pb *Profile) error {
if err := p.compatible(pb); err != nil {
return err
}
baseVals := make([]int64, len(p.SampleType))
for _, s := range pb.Sample {
for i, v := range s.Value {
baseVals[i] += v
}
}
srcVals := make([]int64, len(p.SampleType))
for _, s := range p.Sample {
for i, v := range s.Value {
srcVals[i] += v
}
}
normScale := make([]float64, len(baseVals))
for i := range baseVals {
if srcVals[i] == 0 {
normScale[i] = 0.0
} else {
normScale[i] = float64(baseVals[i]) / float64(srcVals[i])
}
}
p.ScaleN(normScale)
return nil
}
func isZeroSample(s *Sample) bool {
for _, v := range s.Value {
if v != 0 {
return false
}
}
return true
}
type profileMerger struct {
p *Profile
// Memoization tables within a profile.
locationsByID locationIDMap
functionsByID map[uint64]*Function
mappingsByID map[uint64]mapInfo
// Memoization tables for profile entities.
samples map[sampleKey]*Sample
locations map[locationKey]*Location
functions map[functionKey]*Function
mappings map[mappingKey]*Mapping
}
type mapInfo struct {
m *Mapping
offset int64
}
func (pm *profileMerger) mapSample(src *Sample) *Sample {
// Check memoization table
k := pm.sampleKey(src)
if ss, ok := pm.samples[k]; ok {
for i, v := range src.Value {
ss.Value[i] += v
}
return ss
}
// Make new sample.
s := &Sample{
Location: make([]*Location, len(src.Location)),
Value: make([]int64, len(src.Value)),
Label: make(map[string][]string, len(src.Label)),
NumLabel: make(map[string][]int64, len(src.NumLabel)),
NumUnit: make(map[string][]string, len(src.NumLabel)),
}
for i, l := range src.Location {
s.Location[i] = pm.mapLocation(l)
}
for k, v := range src.Label {
vv := make([]string, len(v))
copy(vv, v)
s.Label[k] = vv
}
for k, v := range src.NumLabel {
u := src.NumUnit[k]
vv := make([]int64, len(v))
uu := make([]string, len(u))
copy(vv, v)
copy(uu, u)
s.NumLabel[k] = vv
s.NumUnit[k] = uu
}
copy(s.Value, src.Value)
pm.samples[k] = s
pm.p.Sample = append(pm.p.Sample, s)
return s
}
func (pm *profileMerger) sampleKey(sample *Sample) sampleKey {
// Accumulate contents into a string.
var buf strings.Builder
buf.Grow(64) // Heuristic to avoid extra allocs
// encode a number
putNumber := func(v uint64) {
var num [binary.MaxVarintLen64]byte
n := binary.PutUvarint(num[:], v)
buf.Write(num[:n])
}
// encode a string prefixed with its length.
putDelimitedString := func(s string) {
putNumber(uint64(len(s)))
buf.WriteString(s)
}
for _, l := range sample.Location {
// Get the location in the merged profile, which may have a different ID.
if loc := pm.mapLocation(l); loc != nil {
putNumber(loc.ID)
}
}
putNumber(0) // Delimiter
for _, l := range sortedKeys1(sample.Label) {
putDelimitedString(l)
values := sample.Label[l]
putNumber(uint64(len(values)))
for _, v := range values {
putDelimitedString(v)
}
}
for _, l := range sortedKeys2(sample.NumLabel) {
putDelimitedString(l)
values := sample.NumLabel[l]
putNumber(uint64(len(values)))
for _, v := range values {
putNumber(uint64(v))
}
units := sample.NumUnit[l]
putNumber(uint64(len(units)))
for _, v := range units {
putDelimitedString(v)
}
}
return sampleKey(buf.String())
}
type sampleKey string
// sortedKeys1 returns the sorted keys found in a string->[]string map.
//
// Note: this is currently non-generic since github pprof runs golint,
// which does not support generics. When that issue is fixed, it can
// be merged with sortedKeys2 and made into a generic function.
func sortedKeys1(m map[string][]string) []string {
if len(m) == 0 {
return nil
}
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
// sortedKeys2 returns the sorted keys found in a string->[]int64 map.
//
// Note: this is currently non-generic since github pprof runs golint,
// which does not support generics. When that issue is fixed, it can
// be merged with sortedKeys1 and made into a generic function.
func sortedKeys2(m map[string][]int64) []string {
if len(m) == 0 {
return nil
}
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func (pm *profileMerger) mapLocation(src *Location) *Location {
if src == nil {
return nil
}
if l := pm.locationsByID.get(src.ID); l != nil {
return l
}
mi := pm.mapMapping(src.Mapping)
l := &Location{
ID: uint64(len(pm.p.Location) + 1),
Mapping: mi.m,
Address: uint64(int64(src.Address) + mi.offset),
Line: make([]Line, len(src.Line)),
IsFolded: src.IsFolded,
}
for i, ln := range src.Line {
l.Line[i] = pm.mapLine(ln)
}
// Check memoization table. Must be done on the remapped location to
// account for the remapped mapping ID.
k := l.key()
if ll, ok := pm.locations[k]; ok {
pm.locationsByID.set(src.ID, ll)
return ll
}
pm.locationsByID.set(src.ID, l)
pm.locations[k] = l
pm.p.Location = append(pm.p.Location, l)
return l
}
// key generates locationKey to be used as a key for maps.
func (l *Location) key() locationKey {
key := locationKey{
addr: l.Address,
isFolded: l.IsFolded,
}
if l.Mapping != nil {
// Normalizes address to handle address space randomization.
key.addr -= l.Mapping.Start
key.mappingID = l.Mapping.ID
}
lines := make([]string, len(l.Line)*3)
for i, line := range l.Line {
if line.Function != nil {
lines[i*2] = strconv.FormatUint(line.Function.ID, 16)
}
lines[i*2+1] = strconv.FormatInt(line.Line, 16)
lines[i*2+2] = strconv.FormatInt(line.Column, 16)
}
key.lines = strings.Join(lines, "|")
return key
}
type locationKey struct {
addr, mappingID uint64
lines string
isFolded bool
}
func (pm *profileMerger) mapMapping(src *Mapping) mapInfo {
if src == nil {
return mapInfo{}
}
if mi, ok := pm.mappingsByID[src.ID]; ok {
return mi
}
// Check memoization tables.
mk := src.key()
if m, ok := pm.mappings[mk]; ok {
mi := mapInfo{m, int64(m.Start) - int64(src.Start)}
pm.mappingsByID[src.ID] = mi
return mi
}
m := &Mapping{
ID: uint64(len(pm.p.Mapping) + 1),
Start: src.Start,
Limit: src.Limit,
Offset: src.Offset,
File: src.File,
KernelRelocationSymbol: src.KernelRelocationSymbol,
BuildID: src.BuildID,
HasFunctions: src.HasFunctions,
HasFilenames: src.HasFilenames,
HasLineNumbers: src.HasLineNumbers,
HasInlineFrames: src.HasInlineFrames,
}
pm.p.Mapping = append(pm.p.Mapping, m)
// Update memoization tables.
pm.mappings[mk] = m
mi := mapInfo{m, 0}
pm.mappingsByID[src.ID] = mi
return mi
}
// key generates encoded strings of Mapping to be used as a key for
// maps.
func (m *Mapping) key() mappingKey {
// Normalize addresses to handle address space randomization.
// Round up to next 4K boundary to avoid minor discrepancies.
const mapsizeRounding = 0x1000
size := m.Limit - m.Start
size = size + mapsizeRounding - 1
size = size - (size % mapsizeRounding)
key := mappingKey{
size: size,
offset: m.Offset,
}
switch {
case m.BuildID != "":
key.buildIDOrFile = m.BuildID
case m.File != "":
key.buildIDOrFile = m.File
default:
// A mapping containing neither build ID nor file name is a fake mapping. A
// key with empty buildIDOrFile is used for fake mappings so that they are
// treated as the same mapping during merging.
}
return key
}
type mappingKey struct {
size, offset uint64
buildIDOrFile string
}
func (pm *profileMerger) mapLine(src Line) Line {
ln := Line{
Function: pm.mapFunction(src.Function),
Line: src.Line,
Column: src.Column,
}
return ln
}
func (pm *profileMerger) mapFunction(src *Function) *Function {
if src == nil {
return nil
}
if f, ok := pm.functionsByID[src.ID]; ok {
return f
}
k := src.key()
if f, ok := pm.functions[k]; ok {
pm.functionsByID[src.ID] = f
return f
}
f := &Function{
ID: uint64(len(pm.p.Function) + 1),
Name: src.Name,
SystemName: src.SystemName,
Filename: src.Filename,
StartLine: src.StartLine,
}
pm.functions[k] = f
pm.functionsByID[src.ID] = f
pm.p.Function = append(pm.p.Function, f)
return f
}
// key generates a struct to be used as a key for maps.
func (f *Function) key() functionKey {
return functionKey{
f.StartLine,
f.Name,
f.SystemName,
f.Filename,
}
}
type functionKey struct {
startLine int64
name, systemName, fileName string
}
// combineHeaders checks that all profiles can be merged and returns
// their combined profile.
func combineHeaders(srcs []*Profile) (*Profile, error) {
for _, s := range srcs[1:] {
if err := srcs[0].compatible(s); err != nil {
return nil, err
}
}
var timeNanos, durationNanos, period int64
var comments []string
seenComments := map[string]bool{}
var docURL string
var defaultSampleType string
for _, s := range srcs {
if timeNanos == 0 || s.TimeNanos < timeNanos {
timeNanos = s.TimeNanos
}
durationNanos += s.DurationNanos
if period == 0 || period < s.Period {
period = s.Period
}
for _, c := range s.Comments {
if seen := seenComments[c]; !seen {
comments = append(comments, c)
seenComments[c] = true
}
}
if defaultSampleType == "" {
defaultSampleType = s.DefaultSampleType
}
if docURL == "" {
docURL = s.DocURL
}
}
p := &Profile{
SampleType: make([]*ValueType, len(srcs[0].SampleType)),
DropFrames: srcs[0].DropFrames,
KeepFrames: srcs[0].KeepFrames,
TimeNanos: timeNanos,
DurationNanos: durationNanos,
PeriodType: srcs[0].PeriodType,
Period: period,
Comments: comments,
DefaultSampleType: defaultSampleType,
DocURL: docURL,
}
copy(p.SampleType, srcs[0].SampleType)
return p, nil
}
// compatible determines if two profiles can be compared/merged.
// returns nil if the profiles are compatible; otherwise an error with
// details on the incompatibility.
func (p *Profile) compatible(pb *Profile) error {
if !equalValueType(p.PeriodType, pb.PeriodType) {
return fmt.Errorf("incompatible period types %v and %v", p.PeriodType, pb.PeriodType)
}
if len(p.SampleType) != len(pb.SampleType) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
for i := range p.SampleType {
if !equalValueType(p.SampleType[i], pb.SampleType[i]) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
}
return nil
}
// equalValueType returns true if the two value types are semantically
// equal. It ignores the internal fields used during encode/decode.
func equalValueType(st1, st2 *ValueType) bool {
return st1.Type == st2.Type && st1.Unit == st2.Unit
}
// locationIDMap is like a map[uint64]*Location, but provides efficiency for
// ids that are densely numbered, which is often the case.
type locationIDMap struct {
dense []*Location // indexed by id for id < len(dense)
sparse map[uint64]*Location // indexed by id for id >= len(dense)
}
func makeLocationIDMap(n int) locationIDMap {
return locationIDMap{
dense: make([]*Location, n),
sparse: map[uint64]*Location{},
}
}
func (lm locationIDMap) get(id uint64) *Location {
if id < uint64(len(lm.dense)) {
return lm.dense[int(id)]
}
return lm.sparse[id]
}
func (lm locationIDMap) set(id uint64, loc *Location) {
if id < uint64(len(lm.dense)) {
lm.dense[id] = loc
return
}
lm.sparse[id] = loc
}
// CompatibilizeSampleTypes makes profiles compatible to be compared/merged. It
// keeps sample types that appear in all profiles only and drops/reorders the
// sample types as necessary.
//
// In the case of sample types order is not the same for given profiles the
// order is derived from the first profile.
//
// Profiles are modified in-place.
//
// It returns an error if the sample type's intersection is empty.
func CompatibilizeSampleTypes(ps []*Profile) error {
sTypes := commonSampleTypes(ps)
if len(sTypes) == 0 {
return fmt.Errorf("profiles have empty common sample type list")
}
for _, p := range ps {
if err := compatibilizeSampleTypes(p, sTypes); err != nil {
return err
}
}
return nil
}
// commonSampleTypes returns sample types that appear in all profiles in the
// order how they ordered in the first profile.
func commonSampleTypes(ps []*Profile) []string {
if len(ps) == 0 {
return nil
}
sTypes := map[string]int{}
for _, p := range ps {
for _, st := range p.SampleType {
sTypes[st.Type]++
}
}
var res []string
for _, st := range ps[0].SampleType {
if sTypes[st.Type] == len(ps) {
res = append(res, st.Type)
}
}
return res
}
// compatibilizeSampleTypes drops sample types that are not present in sTypes
// list and reorder them if needed.
//
// It sets DefaultSampleType to sType[0] if it is not in sType list.
//
// It assumes that all sample types from the sTypes list are present in the
// given profile otherwise it returns an error.
func compatibilizeSampleTypes(p *Profile, sTypes []string) error {
if len(sTypes) == 0 {
return fmt.Errorf("sample type list is empty")
}
defaultSampleType := sTypes[0]
reMap, needToModify := make([]int, len(sTypes)), false
for i, st := range sTypes {
if st == p.DefaultSampleType {
defaultSampleType = p.DefaultSampleType
}
idx := searchValueType(p.SampleType, st)
if idx < 0 {
return fmt.Errorf("%q sample type is not found in profile", st)
}
reMap[i] = idx
if idx != i {
needToModify = true
}
}
if !needToModify && len(sTypes) == len(p.SampleType) {
return nil
}
p.DefaultSampleType = defaultSampleType
oldSampleTypes := p.SampleType
p.SampleType = make([]*ValueType, len(sTypes))
for i, idx := range reMap {
p.SampleType[i] = oldSampleTypes[idx]
}
values := make([]int64, len(sTypes))
for _, s := range p.Sample {
for i, idx := range reMap {
values[i] = s.Value[idx]
}
s.Value = s.Value[:len(values)]
copy(s.Value, values)
}
return nil
}
func searchValueType(vts []*ValueType, s string) int {
for i, vt := range vts {
if vt.Type == s {
return i
}
}
return -1
}
-869
View File
@@ -1,869 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package profile provides a representation of profile.proto and
// methods to encode/decode profiles in this format.
package profile
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"math"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
)
// Profile is an in-memory representation of profile.proto.
type Profile struct {
SampleType []*ValueType
DefaultSampleType string
Sample []*Sample
Mapping []*Mapping
Location []*Location
Function []*Function
Comments []string
DocURL string
DropFrames string
KeepFrames string
TimeNanos int64
DurationNanos int64
PeriodType *ValueType
Period int64
// The following fields are modified during encoding and copying,
// so are protected by a Mutex.
encodeMu sync.Mutex
commentX []int64
docURLX int64
dropFramesX int64
keepFramesX int64
stringTable []string
defaultSampleTypeX int64
}
// ValueType corresponds to Profile.ValueType
type ValueType struct {
Type string // cpu, wall, inuse_space, etc
Unit string // seconds, nanoseconds, bytes, etc
typeX int64
unitX int64
}
// Sample corresponds to Profile.Sample
type Sample struct {
Location []*Location
Value []int64
// Label is a per-label-key map to values for string labels.
//
// In general, having multiple values for the given label key is strongly
// discouraged - see docs for the sample label field in profile.proto. The
// main reason this unlikely state is tracked here is to make the
// decoding->encoding roundtrip not lossy. But we expect that the value
// slices present in this map are always of length 1.
Label map[string][]string
// NumLabel is a per-label-key map to values for numeric labels. See a note
// above on handling multiple values for a label.
NumLabel map[string][]int64
// NumUnit is a per-label-key map to the unit names of corresponding numeric
// label values. The unit info may be missing even if the label is in
// NumLabel, see the docs in profile.proto for details. When the value is
// slice is present and not nil, its length must be equal to the length of
// the corresponding value slice in NumLabel.
NumUnit map[string][]string
locationIDX []uint64
labelX []label
}
// label corresponds to Profile.Label
type label struct {
keyX int64
// Exactly one of the two following values must be set
strX int64
numX int64 // Integer value for this label
// can be set if numX has value
unitX int64
}
// Mapping corresponds to Profile.Mapping
type Mapping struct {
ID uint64
Start uint64
Limit uint64
Offset uint64
File string
BuildID string
HasFunctions bool
HasFilenames bool
HasLineNumbers bool
HasInlineFrames bool
fileX int64
buildIDX int64
// Name of the kernel relocation symbol ("_text" or "_stext"), extracted from File.
// For linux kernel mappings generated by some tools, correct symbolization depends
// on knowing which of the two possible relocation symbols was used for `Start`.
// This is given to us as a suffix in `File` (e.g. "[kernel.kallsyms]_stext").
//
// Note, this public field is not persisted in the proto. For the purposes of
// copying / merging / hashing profiles, it is considered subsumed by `File`.
KernelRelocationSymbol string
}
// Location corresponds to Profile.Location
type Location struct {
ID uint64
Mapping *Mapping
Address uint64
Line []Line
IsFolded bool
mappingIDX uint64
}
// Line corresponds to Profile.Line
type Line struct {
Function *Function
Line int64
Column int64
functionIDX uint64
}
// Function corresponds to Profile.Function
type Function struct {
ID uint64
Name string
SystemName string
Filename string
StartLine int64
nameX int64
systemNameX int64
filenameX int64
}
// Parse parses a profile and checks for its validity. The input
// may be a gzip-compressed encoded protobuf or one of many legacy
// profile formats which may be unsupported in the future.
func Parse(r io.Reader) (*Profile, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
return ParseData(data)
}
// ParseData parses a profile from a buffer and checks for its
// validity.
func ParseData(data []byte) (*Profile, error) {
var p *Profile
var err error
if len(data) >= 2 && data[0] == 0x1f && data[1] == 0x8b {
gz, err := gzip.NewReader(bytes.NewBuffer(data))
if err == nil {
data, err = io.ReadAll(gz)
}
if err != nil {
return nil, fmt.Errorf("decompressing profile: %v", err)
}
}
if p, err = ParseUncompressed(data); err != nil && err != errNoData && err != errConcatProfile {
p, err = parseLegacy(data)
}
if err != nil {
return nil, fmt.Errorf("parsing profile: %v", err)
}
if err := p.CheckValid(); err != nil {
return nil, fmt.Errorf("malformed profile: %v", err)
}
return p, nil
}
var errUnrecognized = fmt.Errorf("unrecognized profile format")
var errMalformed = fmt.Errorf("malformed profile format")
var errNoData = fmt.Errorf("empty input file")
var errConcatProfile = fmt.Errorf("concatenated profiles detected")
func parseLegacy(data []byte) (*Profile, error) {
parsers := []func([]byte) (*Profile, error){
parseCPU,
parseHeap,
parseGoCount, // goroutine, threadcreate
parseThread,
parseContention,
parseJavaProfile,
}
for _, parser := range parsers {
p, err := parser(data)
if err == nil {
p.addLegacyFrameInfo()
return p, nil
}
if err != errUnrecognized {
return nil, err
}
}
return nil, errUnrecognized
}
// ParseUncompressed parses an uncompressed protobuf into a profile.
func ParseUncompressed(data []byte) (*Profile, error) {
if len(data) == 0 {
return nil, errNoData
}
p := &Profile{}
if err := unmarshal(data, p); err != nil {
return nil, err
}
if err := p.postDecode(); err != nil {
return nil, err
}
return p, nil
}
var libRx = regexp.MustCompile(`([.]so$|[.]so[._][0-9]+)`)
// massageMappings applies heuristic-based changes to the profile
// mappings to account for quirks of some environments.
func (p *Profile) massageMappings() {
// Merge adjacent regions with matching names, checking that the offsets match
if len(p.Mapping) > 1 {
mappings := []*Mapping{p.Mapping[0]}
for _, m := range p.Mapping[1:] {
lm := mappings[len(mappings)-1]
if adjacent(lm, m) {
lm.Limit = m.Limit
if m.File != "" {
lm.File = m.File
}
if m.BuildID != "" {
lm.BuildID = m.BuildID
}
p.updateLocationMapping(m, lm)
continue
}
mappings = append(mappings, m)
}
p.Mapping = mappings
}
// Use heuristics to identify main binary and move it to the top of the list of mappings
for i, m := range p.Mapping {
file := strings.TrimSpace(strings.Replace(m.File, "(deleted)", "", -1))
if len(file) == 0 {
continue
}
if len(libRx.FindStringSubmatch(file)) > 0 {
continue
}
if file[0] == '[' {
continue
}
// Swap what we guess is main to position 0.
p.Mapping[0], p.Mapping[i] = p.Mapping[i], p.Mapping[0]
break
}
// Keep the mapping IDs neatly sorted
for i, m := range p.Mapping {
m.ID = uint64(i + 1)
}
}
// adjacent returns whether two mapping entries represent the same
// mapping that has been split into two. Check that their addresses are adjacent,
// and if the offsets match, if they are available.
func adjacent(m1, m2 *Mapping) bool {
if m1.File != "" && m2.File != "" {
if m1.File != m2.File {
return false
}
}
if m1.BuildID != "" && m2.BuildID != "" {
if m1.BuildID != m2.BuildID {
return false
}
}
if m1.Limit != m2.Start {
return false
}
if m1.Offset != 0 && m2.Offset != 0 {
offset := m1.Offset + (m1.Limit - m1.Start)
if offset != m2.Offset {
return false
}
}
return true
}
func (p *Profile) updateLocationMapping(from, to *Mapping) {
for _, l := range p.Location {
if l.Mapping == from {
l.Mapping = to
}
}
}
func serialize(p *Profile) []byte {
p.encodeMu.Lock()
p.preEncode()
b := marshal(p)
p.encodeMu.Unlock()
return b
}
// Write writes the profile as a gzip-compressed marshaled protobuf.
func (p *Profile) Write(w io.Writer) error {
zw := gzip.NewWriter(w)
defer zw.Close()
_, err := zw.Write(serialize(p))
return err
}
// WriteUncompressed writes the profile as a marshaled protobuf.
func (p *Profile) WriteUncompressed(w io.Writer) error {
_, err := w.Write(serialize(p))
return err
}
// CheckValid tests whether the profile is valid. Checks include, but are
// not limited to:
// - len(Profile.Sample[n].value) == len(Profile.value_unit)
// - Sample.id has a corresponding Profile.Location
func (p *Profile) CheckValid() error {
// Check that sample values are consistent
sampleLen := len(p.SampleType)
if sampleLen == 0 && len(p.Sample) != 0 {
return fmt.Errorf("missing sample type information")
}
for _, s := range p.Sample {
if s == nil {
return fmt.Errorf("profile has nil sample")
}
if len(s.Value) != sampleLen {
return fmt.Errorf("mismatch: sample has %d values vs. %d types", len(s.Value), len(p.SampleType))
}
for _, l := range s.Location {
if l == nil {
return fmt.Errorf("sample has nil location")
}
}
}
// Check that all mappings/locations/functions are in the tables
// Check that there are no duplicate ids
mappings := make(map[uint64]*Mapping, len(p.Mapping))
for _, m := range p.Mapping {
if m == nil {
return fmt.Errorf("profile has nil mapping")
}
if m.ID == 0 {
return fmt.Errorf("found mapping with reserved ID=0")
}
if mappings[m.ID] != nil {
return fmt.Errorf("multiple mappings with same id: %d", m.ID)
}
mappings[m.ID] = m
}
functions := make(map[uint64]*Function, len(p.Function))
for _, f := range p.Function {
if f == nil {
return fmt.Errorf("profile has nil function")
}
if f.ID == 0 {
return fmt.Errorf("found function with reserved ID=0")
}
if functions[f.ID] != nil {
return fmt.Errorf("multiple functions with same id: %d", f.ID)
}
functions[f.ID] = f
}
locations := make(map[uint64]*Location, len(p.Location))
for _, l := range p.Location {
if l == nil {
return fmt.Errorf("profile has nil location")
}
if l.ID == 0 {
return fmt.Errorf("found location with reserved id=0")
}
if locations[l.ID] != nil {
return fmt.Errorf("multiple locations with same id: %d", l.ID)
}
locations[l.ID] = l
if m := l.Mapping; m != nil {
if m.ID == 0 || mappings[m.ID] != m {
return fmt.Errorf("inconsistent mapping %p: %d", m, m.ID)
}
}
for _, ln := range l.Line {
f := ln.Function
if f == nil {
return fmt.Errorf("location id: %d has a line with nil function", l.ID)
}
if f.ID == 0 || functions[f.ID] != f {
return fmt.Errorf("inconsistent function %p: %d", f, f.ID)
}
}
}
return nil
}
// Aggregate merges the locations in the profile into equivalence
// classes preserving the request attributes. It also updates the
// samples to point to the merged locations.
func (p *Profile) Aggregate(inlineFrame, function, filename, linenumber, columnnumber, address bool) error {
for _, m := range p.Mapping {
m.HasInlineFrames = m.HasInlineFrames && inlineFrame
m.HasFunctions = m.HasFunctions && function
m.HasFilenames = m.HasFilenames && filename
m.HasLineNumbers = m.HasLineNumbers && linenumber
}
// Aggregate functions
if !function || !filename {
for _, f := range p.Function {
if !function {
f.Name = ""
f.SystemName = ""
}
if !filename {
f.Filename = ""
}
}
}
// Aggregate locations
if !inlineFrame || !address || !linenumber || !columnnumber {
for _, l := range p.Location {
if !inlineFrame && len(l.Line) > 1 {
l.Line = l.Line[len(l.Line)-1:]
}
if !linenumber {
for i := range l.Line {
l.Line[i].Line = 0
l.Line[i].Column = 0
}
}
if !columnnumber {
for i := range l.Line {
l.Line[i].Column = 0
}
}
if !address {
l.Address = 0
}
}
}
return p.CheckValid()
}
// NumLabelUnits returns a map of numeric label keys to the units
// associated with those keys and a map of those keys to any units
// that were encountered but not used.
// Unit for a given key is the first encountered unit for that key. If multiple
// units are encountered for values paired with a particular key, then the first
// unit encountered is used and all other units are returned in sorted order
// in map of ignored units.
// If no units are encountered for a particular key, the unit is then inferred
// based on the key.
func (p *Profile) NumLabelUnits() (map[string]string, map[string][]string) {
numLabelUnits := map[string]string{}
ignoredUnits := map[string]map[string]bool{}
encounteredKeys := map[string]bool{}
// Determine units based on numeric tags for each sample.
for _, s := range p.Sample {
for k := range s.NumLabel {
encounteredKeys[k] = true
for _, unit := range s.NumUnit[k] {
if unit == "" {
continue
}
if wantUnit, ok := numLabelUnits[k]; !ok {
numLabelUnits[k] = unit
} else if wantUnit != unit {
if v, ok := ignoredUnits[k]; ok {
v[unit] = true
} else {
ignoredUnits[k] = map[string]bool{unit: true}
}
}
}
}
}
// Infer units for keys without any units associated with
// numeric tag values.
for key := range encounteredKeys {
unit := numLabelUnits[key]
if unit == "" {
switch key {
case "alignment", "request":
numLabelUnits[key] = "bytes"
default:
numLabelUnits[key] = key
}
}
}
// Copy ignored units into more readable format
unitsIgnored := make(map[string][]string, len(ignoredUnits))
for key, values := range ignoredUnits {
units := make([]string, len(values))
i := 0
for unit := range values {
units[i] = unit
i++
}
sort.Strings(units)
unitsIgnored[key] = units
}
return numLabelUnits, unitsIgnored
}
// String dumps a text representation of a profile. Intended mainly
// for debugging purposes.
func (p *Profile) String() string {
ss := make([]string, 0, len(p.Comments)+len(p.Sample)+len(p.Mapping)+len(p.Location))
for _, c := range p.Comments {
ss = append(ss, "Comment: "+c)
}
if url := p.DocURL; url != "" {
ss = append(ss, fmt.Sprintf("Doc: %s", url))
}
if pt := p.PeriodType; pt != nil {
ss = append(ss, fmt.Sprintf("PeriodType: %s %s", pt.Type, pt.Unit))
}
ss = append(ss, fmt.Sprintf("Period: %d", p.Period))
if p.TimeNanos != 0 {
ss = append(ss, fmt.Sprintf("Time: %v", time.Unix(0, p.TimeNanos)))
}
if p.DurationNanos != 0 {
ss = append(ss, fmt.Sprintf("Duration: %.4v", time.Duration(p.DurationNanos)))
}
ss = append(ss, "Samples:")
var sh1 string
for _, s := range p.SampleType {
dflt := ""
if s.Type == p.DefaultSampleType {
dflt = "[dflt]"
}
sh1 = sh1 + fmt.Sprintf("%s/%s%s ", s.Type, s.Unit, dflt)
}
ss = append(ss, strings.TrimSpace(sh1))
for _, s := range p.Sample {
ss = append(ss, s.string())
}
ss = append(ss, "Locations")
for _, l := range p.Location {
ss = append(ss, l.string())
}
ss = append(ss, "Mappings")
for _, m := range p.Mapping {
ss = append(ss, m.string())
}
return strings.Join(ss, "\n") + "\n"
}
// string dumps a text representation of a mapping. Intended mainly
// for debugging purposes.
func (m *Mapping) string() string {
bits := ""
if m.HasFunctions {
bits = bits + "[FN]"
}
if m.HasFilenames {
bits = bits + "[FL]"
}
if m.HasLineNumbers {
bits = bits + "[LN]"
}
if m.HasInlineFrames {
bits = bits + "[IN]"
}
return fmt.Sprintf("%d: %#x/%#x/%#x %s %s %s",
m.ID,
m.Start, m.Limit, m.Offset,
m.File,
m.BuildID,
bits)
}
// string dumps a text representation of a location. Intended mainly
// for debugging purposes.
func (l *Location) string() string {
ss := []string{}
locStr := fmt.Sprintf("%6d: %#x ", l.ID, l.Address)
if m := l.Mapping; m != nil {
locStr = locStr + fmt.Sprintf("M=%d ", m.ID)
}
if l.IsFolded {
locStr = locStr + "[F] "
}
if len(l.Line) == 0 {
ss = append(ss, locStr)
}
for li := range l.Line {
lnStr := "??"
if fn := l.Line[li].Function; fn != nil {
lnStr = fmt.Sprintf("%s %s:%d:%d s=%d",
fn.Name,
fn.Filename,
l.Line[li].Line,
l.Line[li].Column,
fn.StartLine)
if fn.Name != fn.SystemName {
lnStr = lnStr + "(" + fn.SystemName + ")"
}
}
ss = append(ss, locStr+lnStr)
// Do not print location details past the first line
locStr = " "
}
return strings.Join(ss, "\n")
}
// string dumps a text representation of a sample. Intended mainly
// for debugging purposes.
func (s *Sample) string() string {
ss := []string{}
var sv string
for _, v := range s.Value {
sv = fmt.Sprintf("%s %10d", sv, v)
}
sv = sv + ": "
for _, l := range s.Location {
sv = sv + fmt.Sprintf("%d ", l.ID)
}
ss = append(ss, sv)
const labelHeader = " "
if len(s.Label) > 0 {
ss = append(ss, labelHeader+labelsToString(s.Label))
}
if len(s.NumLabel) > 0 {
ss = append(ss, labelHeader+numLabelsToString(s.NumLabel, s.NumUnit))
}
return strings.Join(ss, "\n")
}
// labelsToString returns a string representation of a
// map representing labels.
func labelsToString(labels map[string][]string) string {
ls := []string{}
for k, v := range labels {
ls = append(ls, fmt.Sprintf("%s:%v", k, v))
}
sort.Strings(ls)
return strings.Join(ls, " ")
}
// numLabelsToString returns a string representation of a map
// representing numeric labels.
func numLabelsToString(numLabels map[string][]int64, numUnits map[string][]string) string {
ls := []string{}
for k, v := range numLabels {
units := numUnits[k]
var labelString string
if len(units) == len(v) {
values := make([]string, len(v))
for i, vv := range v {
values[i] = fmt.Sprintf("%d %s", vv, units[i])
}
labelString = fmt.Sprintf("%s:%v", k, values)
} else {
labelString = fmt.Sprintf("%s:%v", k, v)
}
ls = append(ls, labelString)
}
sort.Strings(ls)
return strings.Join(ls, " ")
}
// SetLabel sets the specified key to the specified value for all samples in the
// profile.
func (p *Profile) SetLabel(key string, value []string) {
for _, sample := range p.Sample {
if sample.Label == nil {
sample.Label = map[string][]string{key: value}
} else {
sample.Label[key] = value
}
}
}
// RemoveLabel removes all labels associated with the specified key for all
// samples in the profile.
func (p *Profile) RemoveLabel(key string) {
for _, sample := range p.Sample {
delete(sample.Label, key)
}
}
// HasLabel returns true if a sample has a label with indicated key and value.
func (s *Sample) HasLabel(key, value string) bool {
for _, v := range s.Label[key] {
if v == value {
return true
}
}
return false
}
// SetNumLabel sets the specified key to the specified value for all samples in the
// profile. "unit" is a slice that describes the units that each corresponding member
// of "values" is measured in (e.g. bytes or seconds). If there is no relevant
// unit for a given value, that member of "unit" should be the empty string.
// "unit" must either have the same length as "value", or be nil.
func (p *Profile) SetNumLabel(key string, value []int64, unit []string) {
for _, sample := range p.Sample {
if sample.NumLabel == nil {
sample.NumLabel = map[string][]int64{key: value}
} else {
sample.NumLabel[key] = value
}
if sample.NumUnit == nil {
sample.NumUnit = map[string][]string{key: unit}
} else {
sample.NumUnit[key] = unit
}
}
}
// RemoveNumLabel removes all numerical labels associated with the specified key for all
// samples in the profile.
func (p *Profile) RemoveNumLabel(key string) {
for _, sample := range p.Sample {
delete(sample.NumLabel, key)
delete(sample.NumUnit, key)
}
}
// DiffBaseSample returns true if a sample belongs to the diff base and false
// otherwise.
func (s *Sample) DiffBaseSample() bool {
return s.HasLabel("pprof::base", "true")
}
// Scale multiplies all sample values in a profile by a constant and keeps
// only samples that have at least one non-zero value.
func (p *Profile) Scale(ratio float64) {
if ratio == 1 {
return
}
ratios := make([]float64, len(p.SampleType))
for i := range p.SampleType {
ratios[i] = ratio
}
p.ScaleN(ratios)
}
// ScaleN multiplies each sample values in a sample by a different amount
// and keeps only samples that have at least one non-zero value.
func (p *Profile) ScaleN(ratios []float64) error {
if len(p.SampleType) != len(ratios) {
return fmt.Errorf("mismatched scale ratios, got %d, want %d", len(ratios), len(p.SampleType))
}
allOnes := true
for _, r := range ratios {
if r != 1 {
allOnes = false
break
}
}
if allOnes {
return nil
}
fillIdx := 0
for _, s := range p.Sample {
keepSample := false
for i, v := range s.Value {
if ratios[i] != 1 {
val := int64(math.Round(float64(v) * ratios[i]))
s.Value[i] = val
keepSample = keepSample || val != 0
}
}
if keepSample {
p.Sample[fillIdx] = s
fillIdx++
}
}
p.Sample = p.Sample[:fillIdx]
return nil
}
// HasFunctions determines if all locations in this profile have
// symbolized function information.
func (p *Profile) HasFunctions() bool {
for _, l := range p.Location {
if l.Mapping != nil && !l.Mapping.HasFunctions {
return false
}
}
return true
}
// HasFileLines determines if all locations in this profile have
// symbolized file and line number information.
func (p *Profile) HasFileLines() bool {
for _, l := range p.Location {
if l.Mapping != nil && (!l.Mapping.HasFilenames || !l.Mapping.HasLineNumbers) {
return false
}
}
return true
}
// Unsymbolizable returns true if a mapping points to a binary for which
// locations can't be symbolized in principle, at least now. Examples are
// "[vdso]", "[vsyscall]" and some others, see the code.
func (m *Mapping) Unsymbolizable() bool {
name := filepath.Base(m.File)
return strings.HasPrefix(name, "[") || strings.HasPrefix(name, "linux-vdso") || strings.HasPrefix(m.File, "/dev/dri/") || m.File == "//anon"
}
// Copy makes a fully independent copy of a profile.
func (p *Profile) Copy() *Profile {
pp := &Profile{}
if err := unmarshal(serialize(p), pp); err != nil {
panic(err)
}
if err := pp.postDecode(); err != nil {
panic(err)
}
return pp
}
-367
View File
@@ -1,367 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file is a simple protocol buffer encoder and decoder.
// The format is described at
// https://developers.google.com/protocol-buffers/docs/encoding
//
// A protocol message must implement the message interface:
// decoder() []decoder
// encode(*buffer)
//
// The decode method returns a slice indexed by field number that gives the
// function to decode that field.
// The encode method encodes its receiver into the given buffer.
//
// The two methods are simple enough to be implemented by hand rather than
// by using a protocol compiler.
//
// See profile.go for examples of messages implementing this interface.
//
// There is no support for groups, message sets, or "has" bits.
package profile
import (
"errors"
"fmt"
)
type buffer struct {
field int // field tag
typ int // proto wire type code for field
u64 uint64
data []byte
tmp [16]byte
tmpLines []Line // temporary storage used while decoding "repeated Line".
}
type decoder func(*buffer, message) error
type message interface {
decoder() []decoder
encode(*buffer)
}
func marshal(m message) []byte {
var b buffer
m.encode(&b)
return b.data
}
func encodeVarint(b *buffer, x uint64) {
for x >= 128 {
b.data = append(b.data, byte(x)|0x80)
x >>= 7
}
b.data = append(b.data, byte(x))
}
func encodeLength(b *buffer, tag int, len int) {
encodeVarint(b, uint64(tag)<<3|2)
encodeVarint(b, uint64(len))
}
func encodeUint64(b *buffer, tag int, x uint64) {
// append varint to b.data
encodeVarint(b, uint64(tag)<<3)
encodeVarint(b, x)
}
func encodeUint64s(b *buffer, tag int, x []uint64) {
if len(x) > 2 {
// Use packed encoding
n1 := len(b.data)
for _, u := range x {
encodeVarint(b, u)
}
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
encodeUint64(b, tag, u)
}
}
func encodeUint64Opt(b *buffer, tag int, x uint64) {
if x == 0 {
return
}
encodeUint64(b, tag, x)
}
func encodeInt64(b *buffer, tag int, x int64) {
u := uint64(x)
encodeUint64(b, tag, u)
}
func encodeInt64s(b *buffer, tag int, x []int64) {
if len(x) > 2 {
// Use packed encoding
n1 := len(b.data)
for _, u := range x {
encodeVarint(b, uint64(u))
}
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
encodeInt64(b, tag, u)
}
}
func encodeInt64Opt(b *buffer, tag int, x int64) {
if x == 0 {
return
}
encodeInt64(b, tag, x)
}
func encodeString(b *buffer, tag int, x string) {
encodeLength(b, tag, len(x))
b.data = append(b.data, x...)
}
func encodeStrings(b *buffer, tag int, x []string) {
for _, s := range x {
encodeString(b, tag, s)
}
}
func encodeBool(b *buffer, tag int, x bool) {
if x {
encodeUint64(b, tag, 1)
} else {
encodeUint64(b, tag, 0)
}
}
func encodeBoolOpt(b *buffer, tag int, x bool) {
if x {
encodeBool(b, tag, x)
}
}
func encodeMessage(b *buffer, tag int, m message) {
n1 := len(b.data)
m.encode(b)
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
}
func unmarshal(data []byte, m message) (err error) {
b := buffer{data: data, typ: 2}
return decodeMessage(&b, m)
}
func le64(p []byte) uint64 {
return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
}
func le32(p []byte) uint32 {
return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
}
func decodeVarint(data []byte) (uint64, []byte, error) {
var u uint64
for i := 0; ; i++ {
if i >= 10 || i >= len(data) {
return 0, nil, errors.New("bad varint")
}
u |= uint64(data[i]&0x7F) << uint(7*i)
if data[i]&0x80 == 0 {
return u, data[i+1:], nil
}
}
}
func decodeField(b *buffer, data []byte) ([]byte, error) {
x, data, err := decodeVarint(data)
if err != nil {
return nil, err
}
b.field = int(x >> 3)
b.typ = int(x & 7)
b.data = nil
b.u64 = 0
switch b.typ {
case 0:
b.u64, data, err = decodeVarint(data)
if err != nil {
return nil, err
}
case 1:
if len(data) < 8 {
return nil, errors.New("not enough data")
}
b.u64 = le64(data[:8])
data = data[8:]
case 2:
var n uint64
n, data, err = decodeVarint(data)
if err != nil {
return nil, err
}
if n > uint64(len(data)) {
return nil, errors.New("too much data")
}
b.data = data[:n]
data = data[n:]
case 5:
if len(data) < 4 {
return nil, errors.New("not enough data")
}
b.u64 = uint64(le32(data[:4]))
data = data[4:]
default:
return nil, fmt.Errorf("unknown wire type: %d", b.typ)
}
return data, nil
}
func checkType(b *buffer, typ int) error {
if b.typ != typ {
return errors.New("type mismatch")
}
return nil
}
func decodeMessage(b *buffer, m message) error {
if err := checkType(b, 2); err != nil {
return err
}
dec := m.decoder()
data := b.data
for len(data) > 0 {
// pull varint field# + type
var err error
data, err = decodeField(b, data)
if err != nil {
return err
}
if b.field >= len(dec) || dec[b.field] == nil {
continue
}
if err := dec[b.field](b, m); err != nil {
return err
}
}
return nil
}
func decodeInt64(b *buffer, x *int64) error {
if err := checkType(b, 0); err != nil {
return err
}
*x = int64(b.u64)
return nil
}
func decodeInt64s(b *buffer, x *[]int64) error {
if b.typ == 2 {
// Packed encoding
data := b.data
for len(data) > 0 {
var u uint64
var err error
if u, data, err = decodeVarint(data); err != nil {
return err
}
*x = append(*x, int64(u))
}
return nil
}
var i int64
if err := decodeInt64(b, &i); err != nil {
return err
}
*x = append(*x, i)
return nil
}
func decodeUint64(b *buffer, x *uint64) error {
if err := checkType(b, 0); err != nil {
return err
}
*x = b.u64
return nil
}
func decodeUint64s(b *buffer, x *[]uint64) error {
if b.typ == 2 {
data := b.data
// Packed encoding
for len(data) > 0 {
var u uint64
var err error
if u, data, err = decodeVarint(data); err != nil {
return err
}
*x = append(*x, u)
}
return nil
}
var u uint64
if err := decodeUint64(b, &u); err != nil {
return err
}
*x = append(*x, u)
return nil
}
func decodeString(b *buffer, x *string) error {
if err := checkType(b, 2); err != nil {
return err
}
*x = string(b.data)
return nil
}
func decodeStrings(b *buffer, x *[]string) error {
var s string
if err := decodeString(b, &s); err != nil {
return err
}
*x = append(*x, s)
return nil
}
func decodeBool(b *buffer, x *bool) error {
if err := checkType(b, 0); err != nil {
return err
}
if int64(b.u64) == 0 {
*x = false
} else {
*x = true
}
return nil
}
-194
View File
@@ -1,194 +0,0 @@
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implements methods to remove frames from profiles.
package profile
import (
"fmt"
"regexp"
"strings"
)
var (
reservedNames = []string{"(anonymous namespace)", "operator()"}
bracketRx = func() *regexp.Regexp {
var quotedNames []string
for _, name := range append(reservedNames, "(") {
quotedNames = append(quotedNames, regexp.QuoteMeta(name))
}
return regexp.MustCompile(strings.Join(quotedNames, "|"))
}()
)
// simplifyFunc does some primitive simplification of function names.
func simplifyFunc(f string) string {
// Account for leading '.' on the PPC ELF v1 ABI.
funcName := strings.TrimPrefix(f, ".")
// Account for unsimplified names -- try to remove the argument list by trimming
// starting from the first '(', but skipping reserved names that have '('.
for _, ind := range bracketRx.FindAllStringSubmatchIndex(funcName, -1) {
foundReserved := false
for _, res := range reservedNames {
if funcName[ind[0]:ind[1]] == res {
foundReserved = true
break
}
}
if !foundReserved {
funcName = funcName[:ind[0]]
break
}
}
return funcName
}
// Prune removes all nodes beneath a node matching dropRx, and not
// matching keepRx. If the root node of a Sample matches, the sample
// will have an empty stack.
func (p *Profile) Prune(dropRx, keepRx *regexp.Regexp) {
prune := make(map[uint64]bool)
pruneBeneath := make(map[uint64]bool)
// simplifyFunc can be expensive, so cache results.
// Note that the same function name can be encountered many times due
// different lines and addresses in the same function.
pruneCache := map[string]bool{} // Map from function to whether or not to prune
pruneFromHere := func(s string) bool {
if r, ok := pruneCache[s]; ok {
return r
}
funcName := simplifyFunc(s)
if dropRx.MatchString(funcName) {
if keepRx == nil || !keepRx.MatchString(funcName) {
pruneCache[s] = true
return true
}
}
pruneCache[s] = false
return false
}
for _, loc := range p.Location {
var i int
for i = len(loc.Line) - 1; i >= 0; i-- {
if fn := loc.Line[i].Function; fn != nil && fn.Name != "" {
if pruneFromHere(fn.Name) {
break
}
}
}
if i >= 0 {
// Found matching entry to prune.
pruneBeneath[loc.ID] = true
// Remove the matching location.
if i == len(loc.Line)-1 {
// Matched the top entry: prune the whole location.
prune[loc.ID] = true
} else {
loc.Line = loc.Line[i+1:]
}
}
}
// Prune locs from each Sample
for _, sample := range p.Sample {
// Scan from the root to the leaves to find the prune location.
// Do not prune frames before the first user frame, to avoid
// pruning everything.
foundUser := false
for i := len(sample.Location) - 1; i >= 0; i-- {
id := sample.Location[i].ID
if !prune[id] && !pruneBeneath[id] {
foundUser = true
continue
}
if !foundUser {
continue
}
if prune[id] {
sample.Location = sample.Location[i+1:]
break
}
if pruneBeneath[id] {
sample.Location = sample.Location[i:]
break
}
}
}
}
// RemoveUninteresting prunes and elides profiles using built-in
// tables of uninteresting function names.
func (p *Profile) RemoveUninteresting() error {
var keep, drop *regexp.Regexp
var err error
if p.DropFrames != "" {
if drop, err = regexp.Compile("^(" + p.DropFrames + ")$"); err != nil {
return fmt.Errorf("failed to compile regexp %s: %v", p.DropFrames, err)
}
if p.KeepFrames != "" {
if keep, err = regexp.Compile("^(" + p.KeepFrames + ")$"); err != nil {
return fmt.Errorf("failed to compile regexp %s: %v", p.KeepFrames, err)
}
}
p.Prune(drop, keep)
}
return nil
}
// PruneFrom removes all nodes beneath the lowest node matching dropRx, not including itself.
//
// Please see the example below to understand this method as well as
// the difference from Prune method.
//
// A sample contains Location of [A,B,C,B,D] where D is the top frame and there's no inline.
//
// PruneFrom(A) returns [A,B,C,B,D] because there's no node beneath A.
// Prune(A, nil) returns [B,C,B,D] by removing A itself.
//
// PruneFrom(B) returns [B,C,B,D] by removing all nodes beneath the first B when scanning from the bottom.
// Prune(B, nil) returns [D] because a matching node is found by scanning from the root.
func (p *Profile) PruneFrom(dropRx *regexp.Regexp) {
pruneBeneath := make(map[uint64]bool)
for _, loc := range p.Location {
for i := 0; i < len(loc.Line); i++ {
if fn := loc.Line[i].Function; fn != nil && fn.Name != "" {
funcName := simplifyFunc(fn.Name)
if dropRx.MatchString(funcName) {
// Found matching entry to prune.
pruneBeneath[loc.ID] = true
loc.Line = loc.Line[i:]
break
}
}
}
}
// Prune locs from each Sample
for _, sample := range p.Sample {
// Scan from the bottom leaf to the root to find the prune location.
for i, loc := range sample.Location {
if pruneBeneath[loc.ID] {
sample.Location = sample.Location[i:]
break
}
}
}
}
+1
View File
@@ -27,6 +27,7 @@ go_library(
"//internal/httprule",
"//utilities",
"@org_golang_google_genproto_googleapis_api//httpbody",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//grpclog",
"@org_golang_google_grpc//health/grpc_health_v1",
+3 -3
View File
@@ -201,13 +201,13 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM
if timeout != 0 {
ctx, _ = context.WithTimeout(ctx, timeout)
}
if len(pairs) == 0 {
return ctx, nil, nil
}
md := metadata.Pairs(pairs...)
for _, mda := range mux.metadataAnnotators {
md = metadata.Join(md, mda(ctx, req))
}
if len(md) == 0 {
return ctx, nil, nil
}
return ctx, md, nil
}

Some files were not shown because too many files have changed in this diff Show More