TUN-10563: introduce QUICConnection interface

The bump of the QUIC library introduces a cyclic dependency between the connection and quic modules hence it is necessary to break this coupling.

Right now, the connection module depends on the quic module for the datagram v2/v3 and to which a QUIC connection (currently an interface) is passed.

As it is there is no issue however, under the hood, interface is a wrapper around an UDP connection and a QUIC connection meaning this type must be exposed to the quic module since the QUIC Connection will no longer be a interface but a struct.

Given the above, these changes introduce an interface, QUICConnection, with the surface used today in cloudflared and a struct, ConnWithCloser, that implements said interface within the quic module.

Closes TUN-10563
This commit is contained in:
Luis Neto
2026-06-01 10:08:38 +01:00
parent 0e84636de9
commit 52519f67e8
10 changed files with 256 additions and 56 deletions
+3 -18
View File
@@ -13,6 +13,7 @@ import (
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection/dialopts"
cfdquic "github.com/cloudflare/cloudflared/quic"
)
var (
@@ -29,7 +30,7 @@ func DialQuic(
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error) {
) (cfdquic.QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, opts, logger)
if err != nil {
return nil, err
@@ -43,11 +44,7 @@ func DialQuic(
}
// wrap the session, so that the UDPConn is closed after session is closed.
conn = &wrapCloseableConnQuicConnection{
conn,
udpConn,
}
return conn, nil
return cfdquic.NewQUICConnection(conn, udpConn)
}
func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, opts dialopts.DialOpts, logger *zerolog.Logger) (*net.UDPConn, error) {
@@ -96,15 +93,3 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.Add
return udpConn, err
}
type wrapCloseableConnQuicConnection struct {
quic.Connection
udpConn *net.UDPConn
}
func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error {
err := w.Connection.CloseWithError(errorCode, reason)
_ = w.udpConn.Close()
return err
}
+4 -4
View File
@@ -41,7 +41,7 @@ const (
// quicConnection represents the type that facilitates Proxying via QUIC streams.
type quicConnection struct {
conn quic.Connection
conn cfdquic.QUICConnection
logger *zerolog.Logger
orchestrator Orchestrator
datagramHandler DatagramSessionHandler
@@ -54,10 +54,10 @@ type quicConnection struct {
gracePeriod time.Duration
}
// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
// NewTunnelConnection takes a [cfdquic.QUICConnection] to wrap it for use with cloudflared application logic.
func NewTunnelConnection(
ctx context.Context,
conn quic.Connection,
conn cfdquic.QUICConnection,
connIndex uint8,
orchestrator Orchestrator,
datagramSessionHandler DatagramSessionHandler,
@@ -169,7 +169,7 @@ func (q *quicConnection) acceptStream(ctx context.Context) error {
func (q *quicConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
defer func() { _ = stream.Close() }()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
+6 -11
View File
@@ -8,9 +8,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
@@ -24,7 +22,6 @@ import (
"github.com/cloudflare/cloudflared/packet"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
)
@@ -34,20 +31,18 @@ const (
demuxChanCapacity = 16
)
var (
errInvalidDestinationIP = errors.New("unable to parse destination IP")
)
var errInvalidDestinationIP = pkgerrors.New("unable to parse destination IP")
// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming
// connection streams.
type DatagramSessionHandler interface {
Serve(context.Context) error
pogs.SessionManager
tunnelpogs.SessionManager
}
type datagramV2Connection struct {
conn quic.Connection
conn cfdquic.QUICConnection
index uint8
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
@@ -69,7 +64,7 @@ type datagramV2Connection struct {
}
func NewDatagramV2Connection(ctx context.Context,
conn quic.Connection,
conn cfdquic.QUICConnection,
originDialer ingress.OriginUDPDialer,
icmpRouter ingress.ICMPRouter,
index uint8,
@@ -166,7 +161,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
if err != nil {
originProxy.Close()
_ = originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
@@ -229,7 +224,7 @@ func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uu
}
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
defer func() { _ = stream.Close() }()
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
+9 -9
View File
@@ -7,12 +7,12 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
cfdquic "github.com/cloudflare/cloudflared/quic/v3"
cfdquic "github.com/cloudflare/cloudflared/quic"
cfdquicv3 "github.com/cloudflare/cloudflared/quic/v3"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@@ -22,20 +22,20 @@ var (
)
type datagramV3Connection struct {
conn quic.Connection
conn cfdquic.QUICConnection
index uint8
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer cfdquic.DatagramConn
metrics cfdquic.Metrics
datagramMuxer cfdquicv3.DatagramConn
metrics cfdquicv3.Metrics
logger *zerolog.Logger
}
func NewDatagramV3Connection(ctx context.Context,
conn quic.Connection,
sessionManager cfdquic.SessionManager,
conn cfdquic.QUICConnection,
sessionManager cfdquicv3.SessionManager,
icmpRouter ingress.ICMPRouter,
index uint8,
metrics cfdquic.Metrics,
metrics cfdquicv3.Metrics,
logger *zerolog.Logger,
) DatagramSessionHandler {
log := logger.
@@ -43,7 +43,7 @@ func NewDatagramV3Connection(ctx context.Context,
Int(management.EventTypeKey, int(management.UDP)).
Uint8(LogFieldConnIndex, index).
Logger()
datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
datagramMuxer := cfdquicv3.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
return &datagramV3Connection{
conn,
+7 -6
View File
@@ -17,12 +17,13 @@ import (
reflect "reflect"
time "time"
quic "github.com/quic-go/quic-go"
quic0 "github.com/quic-go/quic-go"
zerolog "github.com/rs/zerolog"
gomock "go.uber.org/mock/gomock"
dialopts "github.com/cloudflare/cloudflared/connection/dialopts"
allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions"
quic "github.com/cloudflare/cloudflared/quic"
)
// MockDNSResolver is a mock of DNSResolver interface.
@@ -176,10 +177,10 @@ func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder {
}
// DialQuic mocks base method.
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) {
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic0.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.QUICConnection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
ret0, _ := ret[0].(quic.Connection)
ret0, _ := ret[0].(quic.QUICConnection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -197,19 +198,19 @@ type MockQUICDialerDialQuicCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.QUICConnection, arg1 error) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
+2 -1
View File
@@ -17,6 +17,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tlsconfig"
)
@@ -96,7 +97,7 @@ func (d *EdgeQUICDialer) DialQuic(
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error) {
) (cfdquic.QUICConnection, error) {
return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
}
+2 -1
View File
@@ -11,6 +11,7 @@ import (
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection/dialopts"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
)
@@ -44,7 +45,7 @@ type QUICDialer interface {
connIndex uint8,
logger *zerolog.Logger,
opts dialopts.DialOpts,
) (quic.Connection, error)
) (cfdquic.QUICConnection, error)
}
// ManagementDialer abstracts the TCP dial to api.cloudflare.com:443 used by
+10 -6
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/packet"
@@ -51,14 +50,14 @@ func (dm *DatagramMuxerV2) mtu() int {
}
type DatagramMuxerV2 struct {
session quic.Connection
session QUICConnection
logger *zerolog.Logger
sessionDemuxChan chan<- *packet.Session
packetDemuxChan chan Packet
}
func NewDatagramMuxerV2(
quicSession quic.Connection,
quicSession QUICConnection,
log *zerolog.Logger,
sessionDemuxChan chan<- *packet.Session,
) *DatagramMuxerV2 {
@@ -110,7 +109,8 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
return nil
}
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
// ServeReceive reads datagrams from the QUIC connection and demuxes them
// depending on whether it's a session or packet
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
for {
msg, err := dm.session.ReceiveDatagram(ctx)
@@ -144,8 +144,10 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error
switch msgType {
case DatagramTypeUDP:
return dm.handleSession(ctx, msg)
default:
case DatagramTypeIP, DatagramTypeIPWithTrace, DatagramTypeTracingSpan:
return dm.handlePacket(ctx, msg, msgType)
default:
return fmt.Errorf("unexpected datagram type %d", msgType)
}
}
@@ -189,8 +191,10 @@ func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType
Spans: spans,
TracingIdentity: tracingIdentity,
}
case DatagramTypeUDP:
return fmt.Errorf("unexpected datagram type %d in handlePacket", msgType)
default:
return fmt.Errorf("Unexpected datagram type %d", msgType)
return fmt.Errorf("unexpected datagram type %d", msgType)
}
select {
case <-ctx.Done():
+105
View File
@@ -0,0 +1,105 @@
package quic
import (
"context"
"errors"
"io"
"net"
"github.com/quic-go/quic-go"
)
// QUICConnection defines the subset of [quic.Connection] methods used by cloudflared.
// Consumers should accept this interface; producers should return [*ConnWithCloser].
type QUICConnection interface {
AcceptStream(ctx context.Context) (quic.Stream, error)
OpenStream() (quic.Stream, error)
OpenStreamSync(ctx context.Context) (quic.Stream, error)
CloseWithError(code quic.ApplicationErrorCode, reason string) error
Context() context.Context
SendDatagram(payload []byte) error
ReceiveDatagram(ctx context.Context) ([]byte, error)
LocalAddr() net.Addr
RemoteAddr() net.Addr
ConnectionState() quic.ConnectionState
}
// Compile-time assertion that *ConnWithCloser implements QUICConnection.
var _ QUICConnection = (*ConnWithCloser)(nil)
var (
// error returned when the [NewConnWithCloser] is called with a nil conn argument
ErrNilQuicConnection = errors.New("the provided quic connection is nil")
// error returned when the [NewConnWithCloser] is called with a nil closer argument
ErrNilCloser = errors.New("the provided closer is nil")
)
// ConnWithCloser wraps a [quic.Connection] and an [io.Closer] (typically the
// underlying [*net.UDPConn]). When [CloseWithError] is called the QUIC
// connection is closed first, then the closer is closed deterministically.
//
// A nil conn is only safe for [CloseWithError] (used in tests). All other
// delegated methods will panic on a nil conn.
type ConnWithCloser struct {
conn quic.Connection
closer io.Closer
}
// NewQUICConnection returns a [*ConnWithCloser] that will close closer after
// the QUIC connection is closed.
func NewQUICConnection(conn quic.Connection, closer io.Closer) (*ConnWithCloser, error) {
if conn == nil {
return nil, ErrNilQuicConnection
}
if closer == nil {
return nil, ErrNilCloser
}
return &ConnWithCloser{conn: conn, closer: closer}, nil
}
// CloseWithError closes the QUIC connection and then closes the underlying
// [io.Closer]. If both operations return errors, the errors are joined so that
// the closer error is no longer silently discarded.
func (c *ConnWithCloser) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
connErr := c.conn.CloseWithError(code, reason)
closerErr := c.closer.Close()
return errors.Join(connErr, closerErr)
}
func (c *ConnWithCloser) AcceptStream(ctx context.Context) (quic.Stream, error) {
return c.conn.AcceptStream(ctx)
}
func (c *ConnWithCloser) OpenStream() (quic.Stream, error) {
return c.conn.OpenStream()
}
func (c *ConnWithCloser) OpenStreamSync(ctx context.Context) (quic.Stream, error) {
return c.conn.OpenStreamSync(ctx)
}
func (c *ConnWithCloser) Context() context.Context {
return c.conn.Context()
}
func (c *ConnWithCloser) SendDatagram(payload []byte) error {
return c.conn.SendDatagram(payload)
}
func (c *ConnWithCloser) ReceiveDatagram(ctx context.Context) ([]byte, error) {
return c.conn.ReceiveDatagram(ctx)
}
func (c *ConnWithCloser) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *ConnWithCloser) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *ConnWithCloser) ConnectionState() quic.ConnectionState {
return c.conn.ConnectionState()
}
+108
View File
@@ -0,0 +1,108 @@
package quic
import (
"errors"
"testing"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
// mockCloser is an [io.Closer] that returns a configurable error.
type mockCloser struct {
closeErr error
}
func (m *mockCloser) Close() error {
return m.closeErr
}
// mockQuicConnection is a minimal test double for [quic.Connection].
type mockQuicConnection struct {
quic.Connection
closeWithErrorErr error
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error {
return m.closeWithErrorErr
}
func TestNewConnWithCloser_NilConn(t *testing.T) {
t.Parallel()
conn, err := NewQUICConnection(nil, &mockCloser{})
require.ErrorIs(t, err, ErrNilQuicConnection)
require.Nil(t, conn)
}
func TestNewConnWithCloser_NilCloser(t *testing.T) {
t.Parallel()
conn, err := NewQUICConnection(&mockQuicConnection{}, nil)
require.ErrorIs(t, err, ErrNilCloser)
require.Nil(t, conn)
}
func TestNewConnWithCloser_Success(t *testing.T) {
t.Parallel()
qc := &mockQuicConnection{}
cl := &mockCloser{}
conn, err := NewQUICConnection(qc, cl)
require.NoError(t, err)
require.NotNil(t, conn)
}
func TestConnWithCloser_CloseWithError_BothSucceed(t *testing.T) {
t.Parallel()
qc := &mockQuicConnection{}
cl := &mockCloser{}
conn, err := NewQUICConnection(qc, cl)
require.NoError(t, err)
err = conn.CloseWithError(0, "test")
require.NoError(t, err)
}
func TestConnWithCloser_CloseWithError_QuicFails(t *testing.T) {
t.Parallel()
quicErr := errors.New("quic close failed")
qc := &mockQuicConnection{closeWithErrorErr: quicErr}
cl := &mockCloser{}
conn, err := NewQUICConnection(qc, cl)
require.NoError(t, err)
err = conn.CloseWithError(0, "test")
require.ErrorIs(t, err, quicErr)
}
func TestConnWithCloser_CloseWithError_CloserFails(t *testing.T) {
t.Parallel()
closerErr := errors.New("closer failed")
qc := &mockQuicConnection{}
cl := &mockCloser{closeErr: closerErr}
conn, err := NewQUICConnection(qc, cl)
require.NoError(t, err)
err = conn.CloseWithError(0, "test")
require.ErrorIs(t, err, closerErr)
}
func TestConnWithCloser_CloseWithError_BothFail(t *testing.T) {
t.Parallel()
quicErr := errors.New("quic close failed")
closerErr := errors.New("closer failed")
qc := &mockQuicConnection{closeWithErrorErr: quicErr}
cl := &mockCloser{closeErr: closerErr}
conn, err := NewQUICConnection(qc, cl)
require.NoError(t, err)
err = conn.CloseWithError(0, "test")
require.ErrorIs(t, err, quicErr)
require.ErrorIs(t, err, closerErr)
}
// TestConnWithCloser_ImplementsInterface is a runtime assertion that
// *ConnWithCloser satisfies QUICConnection. The compile-time assertion is in
// quic_connection.go.
func TestConnWithCloser_ImplementsInterface(t *testing.T) {
t.Parallel()
var _ QUICConnection = (*ConnWithCloser)(nil)
}