diff --git a/connection/dialopts/dialopts.go b/connection/dialopts/dialopts.go new file mode 100644 index 00000000..390b87ef --- /dev/null +++ b/connection/dialopts/dialopts.go @@ -0,0 +1,9 @@ +package dialopts + +// DialOpts holds the configuration for dialing a QUIC connection. +type DialOpts struct { + // SkipPortReuse skips UDP port reuse. This is useful for probe connections + // that should use a random ephemeral port to avoid interfering with the + // main connection flow. + SkipPortReuse bool +} diff --git a/connection/quic.go b/connection/quic.go index 3109d77f..229a72ed 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -11,6 +11,8 @@ import ( "github.com/quic-go/quic-go" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection/dialopts" ) var ( @@ -26,8 +28,9 @@ func DialQuic( localAddr net.IP, connIndex uint8, logger *zerolog.Logger, + opts dialopts.DialOpts, ) (quic.Connection, error) { - udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, logger) + udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, opts, logger) if err != nil { return nil, err } @@ -35,7 +38,7 @@ func DialQuic( conn, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig) if err != nil { // close the udp server socket in case of error connecting to the edge - udpConn.Close() + _ = udpConn.Close() return nil, &EdgeQuicDialError{Cause: err} } @@ -47,10 +50,7 @@ func DialQuic( return conn, nil } -func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, logger *zerolog.Logger) (*net.UDPConn, error) { - portMapMutex.Lock() - defer portMapMutex.Unlock() - +func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, opts dialopts.DialOpts, logger *zerolog.Logger) (*net.UDPConn, error) { listenNetwork := "udp" // https://github.com/quic-go/quic-go/issues/3793 DF bit cannot be set for dual stack listener ("udp") on macOS, // to set the 
DF bit properly, the network string needs to be specific to the IP family. @@ -62,15 +62,24 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.Add } } + // Probes skip port reuse entirely to avoid interfering with the main connection flow. + // They use a random ephemeral port for each dial. + if opts.SkipPortReuse { + return net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: 0}) + } + + portMapMutex.Lock() + defer portMapMutex.Unlock() + // if port was not set yet, it will be zero, so bind will randomly allocate one. if port, ok := portForConnIndex[connIndex]; ok { udpConn, err := net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: port}) // if there wasn't an error, or if port was 0 (independently of error or not, just return) if err == nil { return udpConn, nil - } else { - logger.Debug().Err(err).Msgf("Unable to reuse port %d for connIndex %d. Falling back to random allocation.", port, connIndex) } + + logger.Debug().Err(err).Msgf("Unable to reuse port %d for connIndex %d. Falling back to random allocation.", port, connIndex) } // if we reached here, then there was an error or port as not been allocated it. 
@@ -95,7 +104,7 @@ type wrapCloseableConnQuicConnection struct { func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error { err := w.Connection.CloseWithError(errorCode, reason) - w.udpConn.Close() + _ = w.udpConn.Close() return err } diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 7e444404..2df741a6 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -29,6 +29,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/nettest" + "github.com/cloudflare/cloudflared/connection/dialopts" + "github.com/cloudflare/cloudflared/client" "github.com/cloudflare/cloudflared/config" cfdflow "github.com/cloudflare/cloudflared/flow" @@ -156,7 +158,7 @@ func TestQUICServer(t *testing.T) { require.NoError(t, err) udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) require.NoError(t, err) - defer udpListener.Close() + defer func() { _ = udpListener.Close() }() quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16} quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig) require.NoError(t, err) @@ -523,7 +525,7 @@ func TestServeUDPSession(t *testing.T) { require.NoError(t, err) udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) require.NoError(t, err) - defer udpListener.Close() + defer func() { _ = udpListener.Close() }() ctx, cancel := context.WithCancel(t.Context()) @@ -614,7 +616,7 @@ func TestTCPProxy_FlowRateLimited(t *testing.T) { udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) require.NoError(t, err) - defer udpListener.Close() + defer func() { _ = udpListener.Close() }() quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16} quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig) @@ -658,7 +660,7 @@ func TestTCPProxy_FlowRateLimited(t *testing.T) { func testCreateUDPConnReuseSourcePortForEdgeIP(t 
*testing.T, edgeIP netip.AddrPort) { logger := zerolog.Nop() - conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger) + conn, err := createUDPConnForConnIndex(0, nil, edgeIP, dialopts.DialOpts{}, &logger) require.NoError(t, err) getPortFunc := func(conn *net.UDPConn) int { @@ -669,24 +671,114 @@ func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPo initialPort := getPortFunc(conn) // close conn - conn.Close() + _ = conn.Close() // should get the same port as before. - conn, err = createUDPConnForConnIndex(0, nil, edgeIP, &logger) + conn, err = createUDPConnForConnIndex(0, nil, edgeIP, dialopts.DialOpts{}, &logger) require.NoError(t, err) require.Equal(t, initialPort, getPortFunc(conn)) // new index, should get a different port - conn1, err := createUDPConnForConnIndex(1, nil, edgeIP, &logger) + conn1, err := createUDPConnForConnIndex(1, nil, edgeIP, dialopts.DialOpts{}, &logger) require.NoError(t, err) require.NotEqual(t, initialPort, getPortFunc(conn1)) // not closing the conn and trying to obtain a new conn for same index should give a different random port - conn, err = createUDPConnForConnIndex(0, nil, edgeIP, &logger) + conn, err = createUDPConnForConnIndex(0, nil, edgeIP, dialopts.DialOpts{}, &logger) require.NoError(t, err) require.NotEqual(t, initialPort, getPortFunc(conn)) } +// TestSkipPortReuse tests that skipPortReuse uses a random ephemeral port for each dial. 
+func TestSkipPortReuse(t *testing.T) { + t.Parallel() + logger := zerolog.Nop() + edgeIP := netip.MustParseAddrPort("127.0.0.1:0") + + // First dial with skipPortReuse should allocate a random port + conn1, err := createUDPConnForConnIndex(0, nil, edgeIP, dialopts.DialOpts{SkipPortReuse: true}, &logger) + require.NoError(t, err) + port1 := conn1.LocalAddr().(*net.UDPAddr).Port + + // Don't close conn1 yet - keep it open to prevent port reuse + // Second dial with skipPortReuse should allocate a different random port + conn2, err := createUDPConnForConnIndex(0, nil, edgeIP, dialopts.DialOpts{SkipPortReuse: true}, &logger) + require.NoError(t, err) + port2 := conn2.LocalAddr().(*net.UDPAddr).Port + + // Now close both connections + _ = conn1.Close() + _ = conn2.Close() + // With skipPortReuse, ports should be different (random allocation) + require.NotEqual(t, port1, port2, "With skipPortReuse, each dial should use a different random port") +} + +// TestDialQuicWithSkipPortReuse tests that DialQuic works correctly with the DialOpts.SkipPortReuse option. 
+func TestDialQuicWithSkipPortReuse(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + // Start a mock QUIC server (similar to TestQUICServer) + udpListener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + require.NoError(t, err) + defer func() { _ = udpListener.Close() }() + + serverAddr := netip.MustParseAddrPort(udpListener.LocalAddr().String()) + + quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16} + quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig) + require.NoError(t, err) + + serverDone := make(chan struct{}) + go func() { + // Accept one connection + session, err := quicListener.Accept(ctx) + if err != nil { + close(serverDone) + return + } + // Keep session open until context is cancelled + <-ctx.Done() + _ = session.CloseWithError(0, "test done") + close(serverDone) + }() + + // Test DialQuic with the SkipPortReuse option + tlsClientConfig := &tls.Config{ + // nolint: gosec + InsecureSkipVerify: true, + NextProtos: []string{"argotunnel"}, + } + + log := zerolog.New(io.Discard) + dialCtx, dialCancel := context.WithTimeout(t.Context(), 5*time.Second) + defer dialCancel() + + // Dial with skipPortReuse option - should use a random ephemeral port + conn, err := DialQuic( + dialCtx, + testQUICConfig, + tlsClientConfig, + serverAddr, + nil, // connect on a random port + 0, + &log, + dialopts.DialOpts{SkipPortReuse: true}, + ) + require.NoError(t, err) + require.NotNil(t, conn) + + // Verify we can get connection state + _ = conn.ConnectionState() + + // Clean up + _ = conn.CloseWithError(0, "test done") + cancel() + <-serverDone +} + func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { payload := []byte(t.Name()) sessionID := uuid.New() @@ -719,7 +811,7 @@ func serveSession(ctx context.Context, datagramConn 
*datagramV2Connection, edgeQ // Close connection to terminate session switch closeType { case closedByOrigin: - originConn.Close() + _ = originConn.Close() case closedByRemote: err = datagramConn.UnregisterUdpSession(ctx, sessionID, expectedReason) require.NoError(t, err) @@ -813,6 +905,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) nil, // connect on a random port index, &log, + dialopts.DialOpts{}, ) require.NoError(t, err) diff --git a/prechecks/interfaces.go b/prechecks/interfaces.go index d8e97b01..28944881 100644 --- a/prechecks/interfaces.go +++ b/prechecks/interfaces.go @@ -10,6 +10,8 @@ import ( "github.com/quic-go/quic-go" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/connection/dialopts" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" ) @@ -65,6 +67,7 @@ type QUICDialer interface { localAddr net.IP, connIndex uint8, logger *zerolog.Logger, + opts dialopts.DialOpts, ) (quic.Connection, error) } diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index d672488d..2a2f41c0 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -19,6 +19,7 @@ import ( "github.com/cloudflare/cloudflared/client" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/connection/dialopts" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/features" @@ -129,23 +130,23 @@ type EdgeAddrHandler interface { ShouldGetNewAddress(connIndex uint8, err error) (needsNewAddress bool, connectivityError error) } -func NewIPAddrFallback(maxRetries uint8) *ipAddrFallback { - return &ipAddrFallback{ +func NewIPAddrFallback(maxRetries uint8) *IpAddrFallback { + return &IpAddrFallback{ retriesByConnIndex: make(map[uint8]uint8), maxRetries: maxRetries, } } -// ipAddrFallback will have more conditions to fall back to a new address for certain +// IpAddrFallback will have more conditions to fall 
back to a new address for certain // edge connection errors. This means that this handler will return true for isConnectivityError // for more cases like duplicate connection register and edge quic dial errors. -type ipAddrFallback struct { +type IpAddrFallback struct { m sync.Mutex retriesByConnIndex map[uint8]uint8 maxRetries uint8 } -func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsNewAddress bool, connectivityError error) { +func (f *IpAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsNewAddress bool, connectivityError error) { f.m.Lock() defer f.m.Unlock() switch err.(type) { @@ -597,6 +598,7 @@ func (e *EdgeTunnelServer) serveQUIC( e.edgeBindAddr, connIndex, connLogger.Logger(), + dialopts.DialOpts{}, ) if err != nil { connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection")