mirror of
https://github.com/cloudflare/cloudflared.git
synced 2026-06-23 04:10:20 +00:00
TUN-10563: introduce QUICConnection interface
The bump of the QUIC library introduces a cyclic dependency between the connection and quic modules hence it is necessary to break this coupling. Right now, the connection module depends on the quic module for the datagram v2/v3 and to which a QUIC connection (currently an interface) is passed. As it is there is no issue however, under the hood, interface is a wrapper around an UDP connection and a QUIC connection meaning this type must be exposed to the quic module since the QUIC Connection will no longer be a interface but a struct. Given the above, these changes introduce an interface, QUICConnection, with the surface used today in cloudflared and a struct, ConnWithCloser, that implements said interface within the quic module. Closes TUN-10563
This commit is contained in:
+3
-18
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection/dialopts"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -29,7 +30,7 @@ func DialQuic(
|
||||
connIndex uint8,
|
||||
logger *zerolog.Logger,
|
||||
opts dialopts.DialOpts,
|
||||
) (quic.Connection, error) {
|
||||
) (cfdquic.QUICConnection, error) {
|
||||
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, opts, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -43,11 +44,7 @@ func DialQuic(
|
||||
}
|
||||
|
||||
// wrap the session, so that the UDPConn is closed after session is closed.
|
||||
conn = &wrapCloseableConnQuicConnection{
|
||||
conn,
|
||||
udpConn,
|
||||
}
|
||||
return conn, nil
|
||||
return cfdquic.NewQUICConnection(conn, udpConn)
|
||||
}
|
||||
|
||||
func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, opts dialopts.DialOpts, logger *zerolog.Logger) (*net.UDPConn, error) {
|
||||
@@ -96,15 +93,3 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.Add
|
||||
|
||||
return udpConn, err
|
||||
}
|
||||
|
||||
type wrapCloseableConnQuicConnection struct {
|
||||
quic.Connection
|
||||
udpConn *net.UDPConn
|
||||
}
|
||||
|
||||
func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error {
|
||||
err := w.Connection.CloseWithError(errorCode, reason)
|
||||
_ = w.udpConn.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
|
||||
// quicConnection represents the type that facilitates Proxying via QUIC streams.
|
||||
type quicConnection struct {
|
||||
conn quic.Connection
|
||||
conn cfdquic.QUICConnection
|
||||
logger *zerolog.Logger
|
||||
orchestrator Orchestrator
|
||||
datagramHandler DatagramSessionHandler
|
||||
@@ -54,10 +54,10 @@ type quicConnection struct {
|
||||
gracePeriod time.Duration
|
||||
}
|
||||
|
||||
// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
|
||||
// NewTunnelConnection takes a [cfdquic.QUICConnection] to wrap it for use with cloudflared application logic.
|
||||
func NewTunnelConnection(
|
||||
ctx context.Context,
|
||||
conn quic.Connection,
|
||||
conn cfdquic.QUICConnection,
|
||||
connIndex uint8,
|
||||
orchestrator Orchestrator,
|
||||
datagramSessionHandler DatagramSessionHandler,
|
||||
@@ -169,7 +169,7 @@ func (q *quicConnection) acceptStream(ctx context.Context) error {
|
||||
func (q *quicConnection) runStream(quicStream quic.Stream) {
|
||||
ctx := quicStream.Context()
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
defer func() { _ = stream.Close() }()
|
||||
|
||||
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
||||
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
|
||||
|
||||
@@ -8,9 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
@@ -24,7 +22,6 @@ import (
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
|
||||
)
|
||||
@@ -34,20 +31,18 @@ const (
|
||||
demuxChanCapacity = 16
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidDestinationIP = errors.New("unable to parse destination IP")
|
||||
)
|
||||
var errInvalidDestinationIP = pkgerrors.New("unable to parse destination IP")
|
||||
|
||||
// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming
|
||||
// connection streams.
|
||||
type DatagramSessionHandler interface {
|
||||
Serve(context.Context) error
|
||||
|
||||
pogs.SessionManager
|
||||
tunnelpogs.SessionManager
|
||||
}
|
||||
|
||||
type datagramV2Connection struct {
|
||||
conn quic.Connection
|
||||
conn cfdquic.QUICConnection
|
||||
index uint8
|
||||
|
||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||
@@ -69,7 +64,7 @@ type datagramV2Connection struct {
|
||||
}
|
||||
|
||||
func NewDatagramV2Connection(ctx context.Context,
|
||||
conn quic.Connection,
|
||||
conn cfdquic.QUICConnection,
|
||||
originDialer ingress.OriginUDPDialer,
|
||||
icmpRouter ingress.ICMPRouter,
|
||||
index uint8,
|
||||
@@ -166,7 +161,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
||||
|
||||
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
|
||||
if err != nil {
|
||||
originProxy.Close()
|
||||
_ = originProxy.Close()
|
||||
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
q.flowLimiter.Release()
|
||||
@@ -229,7 +224,7 @@ func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uu
|
||||
}
|
||||
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
defer func() { _ = stream.Close() }()
|
||||
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic/v3"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
cfdquicv3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
@@ -22,20 +22,20 @@ var (
|
||||
)
|
||||
|
||||
type datagramV3Connection struct {
|
||||
conn quic.Connection
|
||||
conn cfdquic.QUICConnection
|
||||
index uint8
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer cfdquic.DatagramConn
|
||||
metrics cfdquic.Metrics
|
||||
datagramMuxer cfdquicv3.DatagramConn
|
||||
metrics cfdquicv3.Metrics
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewDatagramV3Connection(ctx context.Context,
|
||||
conn quic.Connection,
|
||||
sessionManager cfdquic.SessionManager,
|
||||
conn cfdquic.QUICConnection,
|
||||
sessionManager cfdquicv3.SessionManager,
|
||||
icmpRouter ingress.ICMPRouter,
|
||||
index uint8,
|
||||
metrics cfdquic.Metrics,
|
||||
metrics cfdquicv3.Metrics,
|
||||
logger *zerolog.Logger,
|
||||
) DatagramSessionHandler {
|
||||
log := logger.
|
||||
@@ -43,7 +43,7 @@ func NewDatagramV3Connection(ctx context.Context,
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Uint8(LogFieldConnIndex, index).
|
||||
Logger()
|
||||
datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
|
||||
datagramMuxer := cfdquicv3.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
|
||||
|
||||
return &datagramV3Connection{
|
||||
conn,
|
||||
|
||||
@@ -17,12 +17,13 @@ import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
quic "github.com/quic-go/quic-go"
|
||||
quic0 "github.com/quic-go/quic-go"
|
||||
zerolog "github.com/rs/zerolog"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
|
||||
dialopts "github.com/cloudflare/cloudflared/connection/dialopts"
|
||||
allregions "github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
quic "github.com/cloudflare/cloudflared/quic"
|
||||
)
|
||||
|
||||
// MockDNSResolver is a mock of DNSResolver interface.
|
||||
@@ -176,10 +177,10 @@ func (m *MockQUICDialer) EXPECT() *MockQUICDialerMockRecorder {
|
||||
}
|
||||
|
||||
// DialQuic mocks base method.
|
||||
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.Connection, error) {
|
||||
func (m *MockQUICDialer) DialQuic(ctx context.Context, quicConfig *quic0.Config, tlsConfig *tls.Config, addr netip.AddrPort, localAddr net.IP, connIndex uint8, logger *zerolog.Logger, opts dialopts.DialOpts) (quic.QUICConnection, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialQuic", ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
|
||||
ret0, _ := ret[0].(quic.Connection)
|
||||
ret0, _ := ret[0].(quic.QUICConnection)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -197,19 +198,19 @@ type MockQUICDialerDialQuicCall struct {
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.Connection, arg1 error) *MockQUICDialerDialQuicCall {
|
||||
func (c *MockQUICDialerDialQuicCall) Return(arg0 quic.QUICConnection, arg1 error) *MockQUICDialerDialQuicCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
|
||||
func (c *MockQUICDialerDialQuicCall) Do(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.Connection, error)) *MockQUICDialerDialQuicCall {
|
||||
func (c *MockQUICDialerDialQuicCall) DoAndReturn(f func(context.Context, *quic0.Config, *tls.Config, netip.AddrPort, net.IP, uint8, *zerolog.Logger, dialopts.DialOpts) (quic.QUICConnection, error)) *MockQUICDialerDialQuicCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
+2
-1
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
edgedial "github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
@@ -96,7 +97,7 @@ func (d *EdgeQUICDialer) DialQuic(
|
||||
connIndex uint8,
|
||||
logger *zerolog.Logger,
|
||||
opts dialopts.DialOpts,
|
||||
) (quic.Connection, error) {
|
||||
) (cfdquic.QUICConnection, error) {
|
||||
return connection.DialQuic(ctx, quicConfig, tlsConfig, addr, localAddr, connIndex, logger, opts)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection/dialopts"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ type QUICDialer interface {
|
||||
connIndex uint8,
|
||||
logger *zerolog.Logger,
|
||||
opts dialopts.DialOpts,
|
||||
) (quic.Connection, error)
|
||||
) (cfdquic.QUICConnection, error)
|
||||
}
|
||||
|
||||
// ManagementDialer abstracts the TCP dial to api.cloudflare.com:443 used by
|
||||
|
||||
+10
-6
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
@@ -51,14 +50,14 @@ func (dm *DatagramMuxerV2) mtu() int {
|
||||
}
|
||||
|
||||
type DatagramMuxerV2 struct {
|
||||
session quic.Connection
|
||||
session QUICConnection
|
||||
logger *zerolog.Logger
|
||||
sessionDemuxChan chan<- *packet.Session
|
||||
packetDemuxChan chan Packet
|
||||
}
|
||||
|
||||
func NewDatagramMuxerV2(
|
||||
quicSession quic.Connection,
|
||||
quicSession QUICConnection,
|
||||
log *zerolog.Logger,
|
||||
sessionDemuxChan chan<- *packet.Session,
|
||||
) *DatagramMuxerV2 {
|
||||
@@ -110,7 +109,8 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
|
||||
// ServeReceive reads datagrams from the QUIC connection and demuxes them
|
||||
// depending on whether it's a session or packet
|
||||
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
|
||||
for {
|
||||
msg, err := dm.session.ReceiveDatagram(ctx)
|
||||
@@ -144,8 +144,10 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error
|
||||
switch msgType {
|
||||
case DatagramTypeUDP:
|
||||
return dm.handleSession(ctx, msg)
|
||||
default:
|
||||
case DatagramTypeIP, DatagramTypeIPWithTrace, DatagramTypeTracingSpan:
|
||||
return dm.handlePacket(ctx, msg, msgType)
|
||||
default:
|
||||
return fmt.Errorf("unexpected datagram type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,8 +191,10 @@ func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType
|
||||
Spans: spans,
|
||||
TracingIdentity: tracingIdentity,
|
||||
}
|
||||
case DatagramTypeUDP:
|
||||
return fmt.Errorf("unexpected datagram type %d in handlePacket", msgType)
|
||||
default:
|
||||
return fmt.Errorf("Unexpected datagram type %d", msgType)
|
||||
return fmt.Errorf("unexpected datagram type %d", msgType)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// QUICConnection defines the subset of [quic.Connection] methods used by cloudflared.
|
||||
// Consumers should accept this interface; producers should return [*ConnWithCloser].
|
||||
type QUICConnection interface {
|
||||
AcceptStream(ctx context.Context) (quic.Stream, error)
|
||||
OpenStream() (quic.Stream, error)
|
||||
OpenStreamSync(ctx context.Context) (quic.Stream, error)
|
||||
CloseWithError(code quic.ApplicationErrorCode, reason string) error
|
||||
Context() context.Context
|
||||
SendDatagram(payload []byte) error
|
||||
ReceiveDatagram(ctx context.Context) ([]byte, error)
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
ConnectionState() quic.ConnectionState
|
||||
}
|
||||
|
||||
// Compile-time assertion that *ConnWithCloser implements QUICConnection.
|
||||
var _ QUICConnection = (*ConnWithCloser)(nil)
|
||||
|
||||
var (
|
||||
// error returned when the [NewConnWithCloser] is called with a nil conn argument
|
||||
ErrNilQuicConnection = errors.New("the provided quic connection is nil")
|
||||
// error returned when the [NewConnWithCloser] is called with a nil closer argument
|
||||
ErrNilCloser = errors.New("the provided closer is nil")
|
||||
)
|
||||
|
||||
// ConnWithCloser wraps a [quic.Connection] and an [io.Closer] (typically the
|
||||
// underlying [*net.UDPConn]). When [CloseWithError] is called the QUIC
|
||||
// connection is closed first, then the closer is closed deterministically.
|
||||
//
|
||||
// A nil conn is only safe for [CloseWithError] (used in tests). All other
|
||||
// delegated methods will panic on a nil conn.
|
||||
type ConnWithCloser struct {
|
||||
conn quic.Connection
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
// NewQUICConnection returns a [*ConnWithCloser] that will close closer after
|
||||
// the QUIC connection is closed.
|
||||
func NewQUICConnection(conn quic.Connection, closer io.Closer) (*ConnWithCloser, error) {
|
||||
if conn == nil {
|
||||
return nil, ErrNilQuicConnection
|
||||
}
|
||||
|
||||
if closer == nil {
|
||||
return nil, ErrNilCloser
|
||||
}
|
||||
return &ConnWithCloser{conn: conn, closer: closer}, nil
|
||||
}
|
||||
|
||||
// CloseWithError closes the QUIC connection and then closes the underlying
|
||||
// [io.Closer]. If both operations return errors, the errors are joined so that
|
||||
// the closer error is no longer silently discarded.
|
||||
func (c *ConnWithCloser) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
|
||||
connErr := c.conn.CloseWithError(code, reason)
|
||||
closerErr := c.closer.Close()
|
||||
|
||||
return errors.Join(connErr, closerErr)
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) AcceptStream(ctx context.Context) (quic.Stream, error) {
|
||||
return c.conn.AcceptStream(ctx)
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) OpenStream() (quic.Stream, error) {
|
||||
return c.conn.OpenStream()
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) OpenStreamSync(ctx context.Context) (quic.Stream, error) {
|
||||
return c.conn.OpenStreamSync(ctx)
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) Context() context.Context {
|
||||
return c.conn.Context()
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) SendDatagram(payload []byte) error {
|
||||
return c.conn.SendDatagram(payload)
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
return c.conn.ReceiveDatagram(ctx)
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *ConnWithCloser) ConnectionState() quic.ConnectionState {
|
||||
return c.conn.ConnectionState()
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockCloser is an [io.Closer] that returns a configurable error.
|
||||
type mockCloser struct {
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (m *mockCloser) Close() error {
|
||||
return m.closeErr
|
||||
}
|
||||
|
||||
// mockQuicConnection is a minimal test double for [quic.Connection].
|
||||
type mockQuicConnection struct {
|
||||
quic.Connection
|
||||
closeWithErrorErr error
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, _ string) error {
|
||||
return m.closeWithErrorErr
|
||||
}
|
||||
|
||||
func TestNewConnWithCloser_NilConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := NewQUICConnection(nil, &mockCloser{})
|
||||
require.ErrorIs(t, err, ErrNilQuicConnection)
|
||||
require.Nil(t, conn)
|
||||
}
|
||||
|
||||
func TestNewConnWithCloser_NilCloser(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := NewQUICConnection(&mockQuicConnection{}, nil)
|
||||
require.ErrorIs(t, err, ErrNilCloser)
|
||||
require.Nil(t, conn)
|
||||
}
|
||||
|
||||
func TestNewConnWithCloser_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
qc := &mockQuicConnection{}
|
||||
cl := &mockCloser{}
|
||||
conn, err := NewQUICConnection(qc, cl)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
}
|
||||
|
||||
func TestConnWithCloser_CloseWithError_BothSucceed(t *testing.T) {
|
||||
t.Parallel()
|
||||
qc := &mockQuicConnection{}
|
||||
cl := &mockCloser{}
|
||||
conn, err := NewQUICConnection(qc, cl)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.CloseWithError(0, "test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnWithCloser_CloseWithError_QuicFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
quicErr := errors.New("quic close failed")
|
||||
qc := &mockQuicConnection{closeWithErrorErr: quicErr}
|
||||
cl := &mockCloser{}
|
||||
conn, err := NewQUICConnection(qc, cl)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.CloseWithError(0, "test")
|
||||
require.ErrorIs(t, err, quicErr)
|
||||
}
|
||||
|
||||
func TestConnWithCloser_CloseWithError_CloserFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
closerErr := errors.New("closer failed")
|
||||
qc := &mockQuicConnection{}
|
||||
cl := &mockCloser{closeErr: closerErr}
|
||||
conn, err := NewQUICConnection(qc, cl)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.CloseWithError(0, "test")
|
||||
require.ErrorIs(t, err, closerErr)
|
||||
}
|
||||
|
||||
func TestConnWithCloser_CloseWithError_BothFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
quicErr := errors.New("quic close failed")
|
||||
closerErr := errors.New("closer failed")
|
||||
qc := &mockQuicConnection{closeWithErrorErr: quicErr}
|
||||
cl := &mockCloser{closeErr: closerErr}
|
||||
conn, err := NewQUICConnection(qc, cl)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.CloseWithError(0, "test")
|
||||
require.ErrorIs(t, err, quicErr)
|
||||
require.ErrorIs(t, err, closerErr)
|
||||
}
|
||||
|
||||
// TestConnWithCloser_ImplementsInterface is a runtime assertion that
|
||||
// *ConnWithCloser satisfies QUICConnection. The compile-time assertion is in
|
||||
// quic_connection.go.
|
||||
func TestConnWithCloser_ImplementsInterface(t *testing.T) {
|
||||
t.Parallel()
|
||||
var _ QUICConnection = (*ConnWithCloser)(nil)
|
||||
}
|
||||
Reference in New Issue
Block a user