mirror of
https://github.com/portainer/portainer.git
synced 2026-06-23 04:10:29 +00:00
fix(websocket): add proper locking and avoid goroutine leakage BE-12835 (#2303)
This commit is contained in:
@@ -40,7 +40,6 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
|
||||
return err
|
||||
}
|
||||
|
||||
agentURL.Scheme = "ws"
|
||||
return handler.doProxyWebsocketRequest(w, r, params, agentURL, false)
|
||||
}
|
||||
|
||||
|
||||
+22
-5
@@ -41,7 +41,7 @@ func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.R
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
errorChan := make(chan error, 2)
|
||||
go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
|
||||
go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
|
||||
|
||||
@@ -52,7 +52,7 @@ func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.R
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msg("session ended")
|
||||
log.Debug().Msg("session ended")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -77,15 +77,27 @@ func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reade
|
||||
out := make([]byte, ReaderBufferSize)
|
||||
input := make(chan string)
|
||||
pingTicker := time.NewTicker(PingPeriod)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
defer logs.CloseAndLogErr(websocketConn)
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
if err := websocketConn.SetReadDeadline(time.Now().Add(PingPeriod + WriteWait)); err != nil {
|
||||
errorChan <- err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
websocketConn.SetPongHandler(func(string) error {
|
||||
return nil
|
||||
return websocketConn.SetReadDeadline(time.Now().Add(PingPeriod + WriteWait))
|
||||
})
|
||||
|
||||
websocketConn.SetPingHandler(func(data string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err := websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -108,7 +120,12 @@ func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reade
|
||||
}
|
||||
|
||||
processedOutput := ValidString(string(out[:n]))
|
||||
input <- processedOutput
|
||||
|
||||
select {
|
||||
case input <- processedOutput:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package ws
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -63,6 +64,13 @@ func StreamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader
|
||||
}
|
||||
|
||||
processedOutput := ValidString(string(out[:n]))
|
||||
|
||||
if err := websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil {
|
||||
errorChan <- err
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if err := websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)); err != nil {
|
||||
errorChan <- err
|
||||
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// newTestWSPair creates an httptest server, upgrades one connection to websocket,
|
||||
// and dials a client connection. Returns (server, client) connections
|
||||
func newTestWSPair(t *testing.T) (*websocket.Conn, *websocket.Conn) {
|
||||
t.Helper()
|
||||
|
||||
connCh := make(chan *websocket.Conn, 1)
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
connCh <- conn
|
||||
<-doneCh
|
||||
}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
close(doneCh)
|
||||
srv.Close()
|
||||
})
|
||||
|
||||
u := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
|
||||
client, resp, err := websocket.DefaultDialer.Dial(u, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { logs.CloseAndLogErr(resp.Body) })
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := <-connCh
|
||||
t.Cleanup(func() {
|
||||
err := server.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return server, client
|
||||
}
|
||||
|
||||
type failWriter struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (w *failWriter) Write(_ []byte) (int, error) {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
type failReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *failReader) Read(_ []byte) (int, error) {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
func TestValidString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := func(input, expected string) {
|
||||
t.Helper()
|
||||
result := ValidString(input)
|
||||
require.Equal(t, expected, result)
|
||||
require.True(t, utf8.ValidString(result))
|
||||
}
|
||||
|
||||
f("hello world", "hello world")
|
||||
// \xff and \xfe are invalid UTF-8 bytes and must be stripped
|
||||
f("hello\xff\xfeworld", "helloworld")
|
||||
f("", "")
|
||||
}
|
||||
|
||||
func TestStreamFromWebsocketToWriter_ForwardsMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, clientConn := newTestWSPair(t)
|
||||
|
||||
var buf bytes.Buffer
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go StreamFromWebsocketToWriter(serverConn, &buf, errorChan)
|
||||
|
||||
err := clientConn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The goroutine sends to errorChan after processing all prior messages,
|
||||
// so buf is fully written by the time we receive here
|
||||
err = <-errorChan
|
||||
require.Error(t, err)
|
||||
|
||||
require.Equal(t, "hello", buf.String())
|
||||
}
|
||||
|
||||
func TestStreamFromWebsocketToWriter_HandlerInterceptsMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, clientConn := newTestWSPair(t)
|
||||
|
||||
var buf bytes.Buffer
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
intercepted := false
|
||||
handler := MessageHandler(func(_ int, data []byte) bool {
|
||||
if string(data) == "intercept" {
|
||||
intercepted = true
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
go StreamFromWebsocketToWriter(serverConn, &buf, errorChan, handler)
|
||||
|
||||
err := clientConn.WriteMessage(websocket.TextMessage, []byte("intercept"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = clientConn.WriteMessage(websocket.TextMessage, []byte("forward"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = <-errorChan
|
||||
require.Error(t, err)
|
||||
|
||||
require.True(t, intercepted)
|
||||
require.Equal(t, "forward", buf.String())
|
||||
}
|
||||
|
||||
func TestStreamFromWebsocketToWriter_WriteError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, clientConn := newTestWSPair(t)
|
||||
|
||||
expectedErr := errors.New("write error")
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go StreamFromWebsocketToWriter(serverConn, &failWriter{err: expectedErr}, errorChan)
|
||||
|
||||
err := clientConn.WriteMessage(websocket.TextMessage, []byte("trigger"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = <-errorChan
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
}
|
||||
|
||||
func TestStreamFromReaderToWebsocket_ForwardsData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, clientConn := newTestWSPair(t)
|
||||
|
||||
reader := strings.NewReader("hello world")
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go StreamFromReaderToWebsocket(serverConn, reader, errorChan)
|
||||
|
||||
msgType, msg, err := clientConn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, websocket.TextMessage, msgType)
|
||||
require.Equal(t, "hello world", string(msg))
|
||||
|
||||
err = <-errorChan
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
}
|
||||
|
||||
func TestStreamFromReaderToWebsocket_ReadError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, _ := newTestWSPair(t)
|
||||
|
||||
expectedErr := errors.New("read error")
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go StreamFromReaderToWebsocket(serverConn, &failReader{err: expectedErr}, errorChan)
|
||||
|
||||
err := <-errorChan
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
}
|
||||
|
||||
func TestWriteReaderToWebSocket_ForwardsData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, clientConn := newTestWSPair(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
reader := strings.NewReader("hello")
|
||||
errorChan := make(chan error, 2)
|
||||
|
||||
go WriteReaderToWebSocket(serverConn, &mu, reader, errorChan)
|
||||
|
||||
msgType, msg, err := clientConn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, websocket.TextMessage, msgType)
|
||||
require.Equal(t, "hello", string(msg))
|
||||
|
||||
// The inner goroutine sends EOF after the reader is exhausted
|
||||
err = <-errorChan
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
}
|
||||
|
||||
func TestWriteReaderToWebSocket_ReaderError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, _ := newTestWSPair(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
expectedErr := errors.New("read error")
|
||||
errorChan := make(chan error, 2)
|
||||
|
||||
go WriteReaderToWebSocket(serverConn, &mu, &failReader{err: expectedErr}, errorChan)
|
||||
|
||||
err := <-errorChan
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
}
|
||||
|
||||
// newTestWSConn creates a websocket server connection without registering a close cleanup
|
||||
// for it, allowing the caller to close it manually without triggering a double-close
|
||||
func newTestWSConn(t *testing.T) *websocket.Conn {
|
||||
t.Helper()
|
||||
|
||||
connCh := make(chan *websocket.Conn, 1)
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
connCh <- conn
|
||||
<-doneCh
|
||||
}))
|
||||
|
||||
t.Cleanup(func() {
|
||||
close(doneCh)
|
||||
srv.Close()
|
||||
})
|
||||
|
||||
u := "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||
|
||||
client, resp, err := websocket.DefaultDialer.Dial(u, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { logs.CloseAndLogErr(resp.Body) })
|
||||
|
||||
t.Cleanup(func() { logs.CloseAndLogErr(client) })
|
||||
|
||||
return <-connCh
|
||||
}
|
||||
|
||||
type mockSizeQueue struct {
|
||||
cols uint16
|
||||
rows uint16
|
||||
}
|
||||
|
||||
func (m *mockSizeQueue) Push(cols, rows uint16) {
|
||||
m.cols = cols
|
||||
m.rows = rows
|
||||
}
|
||||
|
||||
func TestResizeHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := func(msgType int, data []byte, expectedHandled bool, expectedCols, expectedRows uint16) {
|
||||
t.Helper()
|
||||
q := &mockSizeQueue{}
|
||||
handled := ResizeHandler(q)(msgType, data)
|
||||
require.Equal(t, expectedHandled, handled)
|
||||
require.Equal(t, expectedCols, q.cols)
|
||||
require.Equal(t, expectedRows, q.rows)
|
||||
}
|
||||
|
||||
f(websocket.BinaryMessage, []byte(`{"type":"resize","data":{"width":80,"height":24}}`), false, 0, 0)
|
||||
f(websocket.TextMessage, []byte(`not json`), false, 0, 0)
|
||||
f(websocket.TextMessage, []byte(`{"type":"other","data":{"width":80,"height":24}}`), false, 0, 0)
|
||||
f(websocket.TextMessage, []byte(`{"type":"resize","data":{"width":80,"height":24}}`), true, 80, 24)
|
||||
}
|
||||
|
||||
func TestWsPing_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, _ := newTestWSPair(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
err := wsPing(serverConn, &mu)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWsPing_ClosedConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn := newTestWSConn(t)
|
||||
|
||||
err := serverConn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
var mu sync.Mutex
|
||||
err = wsPing(serverConn, &mu)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWsWrite_ClosedConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn := newTestWSConn(t)
|
||||
|
||||
err := serverConn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
var mu sync.Mutex
|
||||
err = wsWrite(serverConn, &mu, "hello")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWriteReaderToWebSocket_ClosedConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn := newTestWSConn(t)
|
||||
|
||||
err := serverConn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
var mu sync.Mutex
|
||||
errorChan := make(chan error, 2)
|
||||
go WriteReaderToWebSocket(serverConn, &mu, strings.NewReader("hello"), errorChan)
|
||||
|
||||
err = <-errorChan
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStreamFromReaderToWebsocket_ClosedConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn := newTestWSConn(t)
|
||||
|
||||
err := serverConn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
go StreamFromReaderToWebsocket(serverConn, strings.NewReader("hello"), errorChan)
|
||||
|
||||
err = <-errorChan
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHijackRequest_UnexpectedStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, _ := newTestWSPair(t)
|
||||
|
||||
backendServer, backendClient := net.Pipe()
|
||||
defer logs.CloseAndLogErr(backendServer)
|
||||
defer logs.CloseAndLogErr(backendClient)
|
||||
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
if _, err := backendServer.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp := "HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n"
|
||||
if _, err := backendServer.Write([]byte(resp)); err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://backend/", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = HijackRequest(serverConn, backendClient, req)
|
||||
require.ErrorContains(t, err, "403")
|
||||
}
|
||||
|
||||
func TestHijackRequest_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConn, _ := newTestWSPair(t)
|
||||
|
||||
backendServer, backendClient := net.Pipe()
|
||||
defer logs.CloseAndLogErr(backendClient)
|
||||
|
||||
go func() {
|
||||
defer logs.CloseAndLogErr(backendServer)
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
if _, err := backendServer.Read(buf); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"
|
||||
if _, err := backendServer.Write([]byte(resp)); err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://backend/", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The backend closes after sending 101, which delivers EOF to the reader goroutine.
|
||||
// HijackRequest treats EOF (not a CloseError) as a clean session end and returns nil.
|
||||
err = HijackRequest(serverConn, backendClient, req)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user