TUN-10384: Probe TLS Helper

Add `ProbeTLSSettings` helper to connection/protocol.go that returns new settings with the `probe.cftunnel.com` SNI for pre-checks.
This commit is contained in:
Miguel da Costa Martins Marcelino
2026-04-14 15:35:03 +00:00
parent e2a71cbecc
commit 5287a9e24b
2 changed files with 71 additions and 18 deletions
+21 -1
View File
@@ -19,6 +19,9 @@ const (
edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge.
edgeQUICServerName = "quic.cftunnel.com"
// probeTLSServerName is the server name used for pre-flight connectivity checks.
probeTLSServerName = "probe.cftunnel.com"
quicProtos = "argotunnel"
AutoSelectFlag = "auto"
// SRV and TXT record resolution TTL
ResolveTTL = time.Hour
@@ -69,7 +72,24 @@ func (p Protocol) TLSSettings() *TLSSettings {
case QUIC:
return &TLSSettings{
ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"},
NextProtos: []string{quicProtos},
}
default:
return nil
}
}
// ProbeTLSSettings returns TLS settings for pre-flight connectivity checks.
func (p Protocol) ProbeTLSSettings() *TLSSettings {
switch p {
case HTTP2:
return &TLSSettings{
ServerName: probeTLSServerName,
}
case QUIC:
return &TLSSettings{
ServerName: probeTLSServerName,
NextProtos: []string{quicProtos},
}
default:
return nil
+50 -17
View File
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/edgediscovery"
)
@@ -14,15 +15,6 @@ const (
testAccountTag = "testAccountTag"
)
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher {
return func() (edgediscovery.ProtocolPercents, error) {
if getError {
return nil, fmt.Errorf("failed to fetch percentage")
}
return protocolPercent, nil
}
}
type dynamicMockFetcher struct {
protocolPercents edgediscovery.ProtocolPercents
err error
@@ -89,14 +81,14 @@ func TestNewProtocolSelector(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Error(t, err, "test %s failed", test.name)
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
require.NoError(t, err, "test %s failed", test.name)
assert.Equalf(t, test.expectedProtocol, selector.Current(), "test %s failed", test.name)
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
assert.Equalf(t, test.hasFallback, ok, "test %s failed", test.name)
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
assert.Equalf(t, test.expectedFallback, fallback, "test %s failed", test.name)
}
}
})
@@ -106,7 +98,7 @@ func TestNewProtocolSelector(t *testing.T) {
func TestAutoProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, QUIC, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
@@ -136,7 +128,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
// Since the user chooses http2 on purpose, we always stick to it.
selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
@@ -165,9 +157,50 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, QUIC, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, QUIC, selector.Current())
}
func TestProbeTLSSettings(t *testing.T) {
tests := []struct {
name string
protocol Protocol
expectedServer string
expectedProtos []string
expectNil bool
}{
{
name: "HTTP2 returns probe SNI",
protocol: HTTP2,
expectedServer: probeTLSServerName,
expectedProtos: nil,
},
{
name: "QUIC returns probe SNI with alpn",
protocol: QUIC,
expectedServer: probeTLSServerName,
expectedProtos: []string{"argotunnel"},
},
{
name: "Unknown protocol returns nil",
protocol: Protocol(999),
expectNil: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
settings := test.protocol.ProbeTLSSettings()
if test.expectNil {
assert.Nil(t, settings)
} else {
assert.NotNil(t, settings)
assert.Equal(t, test.expectedServer, settings.ServerName)
assert.Equal(t, test.expectedProtos, settings.NextProtos)
}
})
}
}