diff --git a/connection/quic.go b/connection/quic.go index 229a72ed..0a7e0484 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -13,6 +13,7 @@ import ( "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/connection/dialopts" + cfdquic "github.com/cloudflare/cloudflared/quic" ) var ( @@ -29,7 +30,7 @@ func DialQuic( connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts, -) (quic.Connection, error) { +) (cfdquic.QUICConnection, error) { udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, opts, logger) if err != nil { return nil, err @@ -43,11 +44,7 @@ func DialQuic( } // wrap the session, so that the UDPConn is closed after session is closed. - conn = &wrapCloseableConnQuicConnection{ - conn, - udpConn, - } - return conn, nil + return cfdquic.NewQUICConnection(conn, udpConn) } func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, opts dialopts.DialOpts, logger *zerolog.Logger) (*net.UDPConn, error) { @@ -96,15 +93,3 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.Add return udpConn, err } - -type wrapCloseableConnQuicConnection struct { - quic.Connection - udpConn *net.UDPConn -} - -func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error { - err := w.Connection.CloseWithError(errorCode, reason) - _ = w.udpConn.Close() - - return err -} diff --git a/connection/quic_connection.go b/connection/quic_connection.go index 88d1e7a2..6133a022 100644 --- a/connection/quic_connection.go +++ b/connection/quic_connection.go @@ -41,7 +41,7 @@ const ( // quicConnection represents the type that facilitates Proxying via QUIC streams. type quicConnection struct { - conn quic.Connection + conn cfdquic.QUICConnection logger *zerolog.Logger orchestrator Orchestrator datagramHandler DatagramSessionHandler @@ -54,10 +54,10 @@ type quicConnection struct { gracePeriod time.Duration } -// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic. +// NewTunnelConnection takes a [cfdquic.QUICConnection] to wrap it for use with cloudflared application logic. func NewTunnelConnection( ctx context.Context, - conn quic.Connection, + conn cfdquic.QUICConnection, connIndex uint8, orchestrator Orchestrator, datagramSessionHandler DatagramSessionHandler, @@ -169,7 +169,7 @@ func (q *quicConnection) acceptStream(ctx context.Context) error { func (q *quicConnection) runStream(quicStream quic.Stream) { ctx := quicStream.Context() stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() + defer func() { _ = stream.Close() }() // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. diff --git a/connection/quic_datagram_v2.go b/connection/quic_datagram_v2.go index e39f06e6..c6dfa170 100644 --- a/connection/quic_datagram_v2.go +++ b/connection/quic_datagram_v2.go @@ -8,9 +8,7 @@ import ( "time" "github.com/google/uuid" - "github.com/pkg/errors" pkgerrors "github.com/pkg/errors" - "github.com/quic-go/quic-go" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -24,7 +22,6 @@ import ( "github.com/cloudflare/cloudflared/packet" cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tracing" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" ) @@ -34,20 +31,18 @@ const ( demuxChanCapacity = 16 ) -var ( - errInvalidDestinationIP = errors.New("unable to parse destination IP") -) +var errInvalidDestinationIP = pkgerrors.New("unable to parse destination IP") // DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming // connection streams. type DatagramSessionHandler interface { Serve(context.Context) error - pogs.SessionManager + tunnelpogs.SessionManager } type datagramV2Connection struct { - conn quic.Connection + conn cfdquic.QUICConnection index uint8 // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer @@ -69,7 +64,7 @@ type datagramV2Connection struct { } func NewDatagramV2Connection(ctx context.Context, - conn quic.Connection, + conn cfdquic.QUICConnection, originDialer ingress.OriginUDPDialer, icmpRouter ingress.ICMPRouter, index uint8, @@ -166,7 +161,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) if err != nil { - originProxy.Close() + _ = originProxy.Close() log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session") tracing.EndWithErrorStatus(registerSpan, err) q.flowLimiter.Release() @@ -229,7 +224,7 @@ func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uu } stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() + defer func() { _ = stream.Close() }() rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout) if err != nil { // Log this at debug because this is not an error if session was closed due to lost connection diff --git a/connection/quic_datagram_v3.go b/connection/quic_datagram_v3.go index c41f8977..4b1d4553 100644 --- a/connection/quic_datagram_v3.go +++ b/connection/quic_datagram_v3.go @@ -7,12 +7,12 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/management" - cfdquic "github.com/cloudflare/cloudflared/quic/v3" + cfdquic "github.com/cloudflare/cloudflared/quic" + cfdquicv3 "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -22,20 +22,20 @@ var ( ) type datagramV3Connection struct { - conn quic.Connection + conn cfdquic.QUICConnection index uint8 // datagramMuxer mux/demux datagrams from quic connection - datagramMuxer cfdquic.DatagramConn - metrics cfdquic.Metrics + datagramMuxer cfdquicv3.DatagramConn + metrics cfdquicv3.Metrics logger *zerolog.Logger } func NewDatagramV3Connection(ctx context.Context, - conn quic.Connection, - sessionManager cfdquic.SessionManager, + conn cfdquic.QUICConnection, + sessionManager cfdquicv3.SessionManager, icmpRouter ingress.ICMPRouter, index uint8, - metrics cfdquic.Metrics, + metrics cfdquicv3.Metrics, logger *zerolog.Logger, ) DatagramSessionHandler { log := logger. @@ -43,7 +43,7 @@ func NewDatagramV3Connection(ctx context.Context, Int(management.EventTypeKey, int(management.UDP)). Uint8(LogFieldConnIndex, index). Logger() - datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log) + datagramMuxer := cfdquicv3.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log) return &datagramV3Connection{ conn, diff --git a/mocks/mock_resolvers.go b/mocks/mock_resolvers.go index 96c41cea..d562728d 100644 --- a/mocks/mock_resolvers.go +++ b/mocks/mock_resolvers.go @@ -17,12 +17,13 @@ import ( reflect "reflect" time "time" - quic "github.com/quic-go/quic-go" + quic0 "github.com/quic-go/quic-go" zerolog "github.com/rs/zerolog" gomock "go.uber.org/mock/gomock" dialopts "github.com/cloudflare/cloudflared/connection/dialopts" allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions" + quic "github.com/cloudflare/cloudflared/quic" ) // MockDNSResolver is a mock of DNSResolver interface. @@ -176,10 +177,10 @@ func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder { } // DialQuic mocks base method. -func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) { +func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic0.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.QUICConnection, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts) - ret0, _ := ret[0].(quic.Connection) + ret0, _ := ret[0].(quic.QUICConnection) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -197,19 +198,19 @@ type MockQUICDialerDialQuicCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall { +func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.QUICConnection, arg1 error) *MockQUICDialerDialQuicCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall { +func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall { +func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/prechecks/probes.go b/prechecks/probes.go index 72a07480..e1f277d7 100644 --- a/prechecks/probes.go +++ b/prechecks/probes.go @@ -17,6 +17,7 @@ import ( "github.com/cloudflare/cloudflared/connection" edgedial "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" + cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tlsconfig" ) @@ -96,7 +97,7 @@ func (d *EdgeQUICDialer) DialQuic( connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts, -) (quic.Connection, error) { +) (cfdquic.QUICConnection, error) { return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts) } diff --git a/prechecks/resolvers.go b/prechecks/resolvers.go index b1372fd0..a3ea0e34 100644 --- a/prechecks/resolvers.go +++ b/prechecks/resolvers.go @@ -11,6 +11,7 @@ import ( "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/connection/dialopts" + cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/edgediscovery/allregions" ) @@ -44,7 +45,7 @@ type QUICDialer interface { connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts, - ) (quic.Connection, error) + ) (cfdquic.QUICConnection, error) } // ManagementDialer abstracts the TCP dial to api.cloudflare.com:443 used by diff --git a/quic/datagramv2.go b/quic/datagramv2.go index 27dfe10e..73712d86 100644 --- a/quic/datagramv2.go +++ b/quic/datagramv2.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/pkg/errors" - "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/packet" @@ -51,14 +50,14 @@ func (dm *DatagramMuxerV2) mtu() int { } type DatagramMuxerV2 struct { - session quic.Connection + session QUICConnection logger *zerolog.Logger sessionDemuxChan chan<- *packet.Session packetDemuxChan chan Packet } func NewDatagramMuxerV2( - quicSession quic.Connection, + quicSession QUICConnection, log *zerolog.Logger, sessionDemuxChan chan<- *packet.Session, ) *DatagramMuxerV2 { @@ -110,7 +109,8 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error { return nil } -// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet +// ServeReceive reads datagrams from the QUIC connection and demuxes them +// depending on whether it's a session or packet func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error { for { msg, err := dm.session.ReceiveDatagram(ctx) @@ -144,8 +144,10 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error switch msgType { case DatagramTypeUDP: return dm.handleSession(ctx, msg) - default: + case DatagramTypeIP, DatagramTypeIPWithTrace, DatagramTypeTracingSpan: return dm.handlePacket(ctx, msg, msgType) + default: + return fmt.Errorf("unexpected datagram type %d", msgType) } } @@ -189,8 +191,10 @@ func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType Spans: spans, TracingIdentity: tracingIdentity, } + case DatagramTypeUDP: + return fmt.Errorf("unexpected datagram type %d in handlePacket", msgType) default: - return fmt.Errorf("Unexpected datagram type %d", msgType) + return fmt.Errorf("unexpected datagram type %d", msgType) } select { case <-ctx.Done(): diff --git a/quic/quic_connection.go b/quic/quic_connection.go new file mode 100644 index 00000000..d8253fe1 --- /dev/null +++ b/quic/quic_connection.go @@ -0,0 +1,105 @@ +package quic + +import ( + "context" + "errors" + "io" + "net" + + "github.com/quic-go/quic-go" +) + +// QUICConnection defines the subset of [quic.Connection] methods used by cloudflared. +// Consumers should accept this interface; producers should return [*ConnWithCloser]. +type QUICConnection interface { + AcceptStream(ctx context.Context) (quic.Stream, error) + OpenStream() (quic.Stream, error) + OpenStreamSync(ctx context.Context) (quic.Stream, error) + CloseWithError(code quic.ApplicationErrorCode, reason string) error + Context() context.Context + SendDatagram(payload []byte) error + ReceiveDatagram(ctx context.Context) ([]byte, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr + ConnectionState() quic.ConnectionState +} + +// Compile-time assertion that *ConnWithCloser implements QUICConnection. +var _ QUICConnection = (*ConnWithCloser)(nil) + +var ( + // error returned when the [NewConnWithCloser] is called with a nil conn argument + ErrNilQuicConnection = errors.New("the provided quic connection is nil") + // error returned when the [NewConnWithCloser] is called with a nil closer argument + ErrNilCloser = errors.New("the provided closer is nil") +) + +// ConnWithCloser wraps a [quic.Connection] and an [io.Closer] (typically the +// underlying [*net.UDPConn]). When [CloseWithError] is called the QUIC +// connection is closed first, then the closer is closed deterministically. +// +// A nil conn is only safe for [CloseWithError] (used in tests). All other +// delegated methods will panic on a nil conn. +type ConnWithCloser struct { + conn quic.Connection + closer io.Closer +} + +// NewQUICConnection returns a [*ConnWithCloser] that will close closer after +// the QUIC connection is closed. +func NewQUICConnection(conn quic.Connection, closer io.Closer) (*ConnWithCloser, error) { + if conn == nil { + return nil, ErrNilQuicConnection + } + + if closer == nil { + return nil, ErrNilCloser + } + return &ConnWithCloser{conn: conn, closer: closer}, nil +} + +// CloseWithError closes the QUIC connection and then closes the underlying +// [io.Closer]. If both operations return errors, the errors are joined so that +// the closer error is no longer silently discarded. +func (c *ConnWithCloser) CloseWithError(code quic.ApplicationErrorCode, reason string) error { + connErr := c.conn.CloseWithError(code, reason) + closerErr := c.closer.Close() + + return errors.Join(connErr, closerErr) +} + +func (c *ConnWithCloser) AcceptStream(ctx context.Context) (quic.Stream, error) { + return c.conn.AcceptStream(ctx) +} + +func (c *ConnWithCloser) OpenStream() (quic.Stream, error) { + return c.conn.OpenStream() +} + +func (c *ConnWithCloser) OpenStreamSync(ctx context.Context) (quic.Stream, error) { + return c.conn.OpenStreamSync(ctx) +} + +func (c *ConnWithCloser) Context() context.Context { + return c.conn.Context() +} + +func (c *ConnWithCloser) SendDatagram(payload []byte) error { + return c.conn.SendDatagram(payload) +} + +func (c *ConnWithCloser) ReceiveDatagram(ctx context.Context) ([]byte, error) { + return c.conn.ReceiveDatagram(ctx) +} + +func (c *ConnWithCloser) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *ConnWithCloser) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *ConnWithCloser) ConnectionState() quic.ConnectionState { + return c.conn.ConnectionState() +} diff --git a/quic/quic_connection_test.go b/quic/quic_connection_test.go new file mode 100644 index 00000000..7f208d5d --- /dev/null +++ b/quic/quic_connection_test.go @@ -0,0 +1,108 @@ +package quic + +import ( + "errors" + "testing" + + "github.com/quic-go/quic-go" + "github.com/stretchr/testify/require" +) + +// mockCloser is an [io.Closer] that returns a configurable error. +type mockCloser struct { + closeErr error +} + +func (m *mockCloser) Close() error { + return m.closeErr +} + +// mockQuicConnection is a minimal test double for [quic.Connection]. +type mockQuicConnection struct { + quic.Connection + closeWithErrorErr error +} + +func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error { + return m.closeWithErrorErr +} + +func TestNewConnWithCloser_NilConn(t *testing.T) { + t.Parallel() + conn, err := NewQUICConnection(nil, &mockCloser{}) + require.ErrorIs(t, err, ErrNilQuicConnection) + require.Nil(t, conn) +} + +func TestNewConnWithCloser_NilCloser(t *testing.T) { + t.Parallel() + conn, err := NewQUICConnection(&mockQuicConnection{}, nil) + require.ErrorIs(t, err, ErrNilCloser) + require.Nil(t, conn) +} + +func TestNewConnWithCloser_Success(t *testing.T) { + t.Parallel() + qc := &mockQuicConnection{} + cl := &mockCloser{} + conn, err := NewQUICConnection(qc, cl) + require.NoError(t, err) + require.NotNil(t, conn) +} + +func TestConnWithCloser_CloseWithError_BothSucceed(t *testing.T) { + t.Parallel() + qc := &mockQuicConnection{} + cl := &mockCloser{} + conn, err := NewQUICConnection(qc, cl) + require.NoError(t, err) + + err = conn.CloseWithError(0, "test") + require.NoError(t, err) +} + +func TestConnWithCloser_CloseWithError_QuicFails(t *testing.T) { + t.Parallel() + quicErr := errors.New("quic close failed") + qc := &mockQuicConnection{closeWithErrorErr: quicErr} + cl := &mockCloser{} + conn, err := NewQUICConnection(qc, cl) + require.NoError(t, err) + + err = conn.CloseWithError(0, "test") + require.ErrorIs(t, err, quicErr) +} + +func TestConnWithCloser_CloseWithError_CloserFails(t *testing.T) { + t.Parallel() + closerErr := errors.New("closer failed") + qc := &mockQuicConnection{} + cl := &mockCloser{closeErr: closerErr} + conn, err := NewQUICConnection(qc, cl) + require.NoError(t, err) + + err = conn.CloseWithError(0, "test") + require.ErrorIs(t, err, closerErr) +} + +func TestConnWithCloser_CloseWithError_BothFail(t *testing.T) { + t.Parallel() + quicErr := errors.New("quic close failed") + closerErr := errors.New("closer failed") + qc := &mockQuicConnection{closeWithErrorErr: quicErr} + cl := &mockCloser{closeErr: closerErr} + conn, err := NewQUICConnection(qc, cl) + require.NoError(t, err) + + err = conn.CloseWithError(0, "test") + require.ErrorIs(t, err, quicErr) + require.ErrorIs(t, err, closerErr) +} + +// TestConnWithCloser_ImplementsInterface is a runtime assertion that +// *ConnWithCloser satisfies QUICConnection. The compile-time assertion is in +// quic_connection.go. +func TestConnWithCloser_ImplementsInterface(t *testing.T) { + t.Parallel() + var _ QUICConnection = (*ConnWithCloser)(nil) +}