From 5287a9e24b7ea658e66da039578d53dfd5e5ea93 Mon Sep 17 00:00:00 2001 From: Miguel da Costa Martins Marcelino Date: Tue, 14 Apr 2026 15:35:03 +0000 Subject: [PATCH] TUN-10384: Probe TLS Helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `ProbeTLSSettings` helper to connection/protocol.go that returns new settings with the `probe.cftunnel.com` SNI for pre-checks. --- connection/protocol.go | 22 +++++++++++- connection/protocol_test.go | 67 +++++++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/connection/protocol.go b/connection/protocol.go index fd53c105..c3070675 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -19,6 +19,9 @@ const ( edgeH2TLSServerName = "h2.cftunnel.com" // edgeQUICServerName is the server name to establish quic connection with edge. edgeQUICServerName = "quic.cftunnel.com" + // probeTLSServerName is the server name used for pre-flight connectivity checks. + probeTLSServerName = "probe.cftunnel.com" + quicProtos = "argotunnel" AutoSelectFlag = "auto" // SRV and TXT record resolution TTL ResolveTTL = time.Hour @@ -69,7 +72,24 @@ func (p Protocol) TLSSettings() *TLSSettings { case QUIC: return &TLSSettings{ ServerName: edgeQUICServerName, - NextProtos: []string{"argotunnel"}, + NextProtos: []string{quicProtos}, + } + default: + return nil + } +} + +// ProbeTLSSettings returns TLS settings for pre-flight connectivity checks. +func (p Protocol) ProbeTLSSettings() *TLSSettings { + switch p { + case HTTP2: + return &TLSSettings{ + ServerName: probeTLSServerName, + } + case QUIC: + return &TLSSettings{ + ServerName: probeTLSServerName, + NextProtos: []string{quicProtos}, } default: return nil diff --git a/connection/protocol_test.go b/connection/protocol_test.go index 12d238c5..9d897ae0 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/edgediscovery" ) @@ -14,15 +15,6 @@ const ( testAccountTag = "testAccountTag" ) -func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher { - return func() (edgediscovery.ProtocolPercents, error) { - if getError { - return nil, fmt.Errorf("failed to fetch percentage") - } - return protocolPercent, nil - } -} - type dynamicMockFetcher struct { protocolPercents edgediscovery.ProtocolPercents err error @@ -89,14 +81,14 @@ func TestNewProtocolSelector(t *testing.T) { t.Run(test.name, func(t *testing.T) { selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log) if test.wantErr { - assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) + assert.Error(t, err, "test %s failed", test.name) } else { - assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) - assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) + require.NoError(t, err, "test %s failed", test.name) + assert.Equalf(t, test.expectedProtocol, selector.Current(), "test %s failed", test.name) fallback, ok := selector.Fallback() - assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) + assert.Equalf(t, test.hasFallback, ok, "test %s failed", test.name) if test.hasFallback { - assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + assert.Equalf(t, test.expectedFallback, fallback, "test %s failed", test.name) } } }) @@ -106,7 +98,7 @@ func TestNewProtocolSelector(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} @@ -136,7 +128,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} // Since the user chooses http2 on purpose, we always stick to it. selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} @@ -165,9 +157,50 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) { fetcher := dynamicMockFetcher{} selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, QUIC, selector.Current()) } + +func TestProbeTLSSettings(t *testing.T) { + tests := []struct { + name string + protocol Protocol + expectedServer string + expectedProtos []string + expectNil bool + }{ + { + name: "HTTP2 returns probe SNI", + protocol: HTTP2, + expectedServer: probeTLSServerName, + expectedProtos: nil, + }, + { + name: "QUIC returns probe SNI with alpn", + protocol: QUIC, + expectedServer: probeTLSServerName, + expectedProtos: []string{"argotunnel"}, + }, + { + name: "Unknown protocol returns nil", + protocol: Protocol(999), + expectNil: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + settings := test.protocol.ProbeTLSSettings() + if test.expectNil { + assert.Nil(t, settings) + } else { + assert.NotNil(t, settings) + assert.Equal(t, test.expectedServer, settings.ServerName) + assert.Equal(t, test.expectedProtos, settings.NextProtos) + } + }) + } +}