fix(websocket): add proper locking and avoid goroutine leakage BE-12835 (#2303)

This commit is contained in:
andres-portainer
2026-04-16 14:08:51 -03:00
committed by GitHub
parent 6ecb99898d
commit b417b04a69
4 changed files with 466 additions and 6 deletions
-1
View File
@@ -40,7 +40,6 @@ func (handler *Handler) proxyAgentWebsocketRequest(w http.ResponseWriter, r *htt
return err
}
agentURL.Scheme = "ws"
return handler.doProxyWebsocketRequest(w, r, params, agentURL, false)
}
+22 -5
View File
@@ -41,7 +41,7 @@ func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.R
var mu sync.Mutex
errorChan := make(chan error, 1)
errorChan := make(chan error, 2)
go StreamFromWebsocketToWriter(websocketConn, conn, errorChan)
go WriteReaderToWebSocket(websocketConn, &mu, conn, errorChan)
@@ -52,7 +52,7 @@ func HijackRequest(websocketConn *websocket.Conn, conn net.Conn, request *http.R
return err
}
log.Info().Msg("session ended")
log.Debug().Msg("session ended")
return nil
}
@@ -77,15 +77,27 @@ func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reade
out := make([]byte, ReaderBufferSize)
input := make(chan string)
pingTicker := time.NewTicker(PingPeriod)
defer pingTicker.Stop()
defer logs.CloseAndLogErr(websocketConn)
done := make(chan struct{})
defer close(done)
if err := websocketConn.SetReadDeadline(time.Now().Add(PingPeriod + WriteWait)); err != nil {
errorChan <- err
return
}
mu.Lock()
websocketConn.SetPongHandler(func(string) error {
return nil
return websocketConn.SetReadDeadline(time.Now().Add(PingPeriod + WriteWait))
})
websocketConn.SetPingHandler(func(data string) error {
mu.Lock()
defer mu.Unlock()
if err := websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil {
return err
}
@@ -108,7 +120,12 @@ func WriteReaderToWebSocket(websocketConn *websocket.Conn, mu *sync.Mutex, reade
}
processedOutput := ValidString(string(out[:n]))
input <- processedOutput
select {
case input <- processedOutput:
case <-done:
return
}
}
}()
+8
View File
@@ -2,6 +2,7 @@ package ws
import (
"io"
"time"
"unicode/utf8"
"github.com/gorilla/websocket"
@@ -63,6 +64,13 @@ func StreamFromReaderToWebsocket(websocketConn *websocket.Conn, reader io.Reader
}
processedOutput := ValidString(string(out[:n]))
if err := websocketConn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil {
errorChan <- err
break
}
if err := websocketConn.WriteMessage(websocket.TextMessage, []byte(processedOutput)); err != nil {
errorChan <- err
+436
View File
@@ -0,0 +1,436 @@
package ws
import (
"bytes"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"unicode/utf8"
"github.com/portainer/portainer/api/logs"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var testUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// newTestWSPair creates an httptest server, upgrades one connection to websocket,
// and dials a client connection. Returns (server, client) connections
func newTestWSPair(t *testing.T) (*websocket.Conn, *websocket.Conn) {
t.Helper()
connCh := make(chan *websocket.Conn, 1)
doneCh := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if !assert.NoError(t, err) {
return
}
connCh <- conn
<-doneCh
}))
t.Cleanup(func() {
close(doneCh)
srv.Close()
})
u := "ws" + strings.TrimPrefix(srv.URL, "http")
client, resp, err := websocket.DefaultDialer.Dial(u, nil)
require.NoError(t, err)
t.Cleanup(func() { logs.CloseAndLogErr(resp.Body) })
t.Cleanup(func() {
err := client.Close()
require.NoError(t, err)
})
server := <-connCh
t.Cleanup(func() {
err := server.Close()
require.NoError(t, err)
})
return server, client
}
type failWriter struct {
err error
}
func (w *failWriter) Write(_ []byte) (int, error) {
return 0, w.err
}
type failReader struct {
err error
}
func (r *failReader) Read(_ []byte) (int, error) {
return 0, r.err
}
func TestValidString(t *testing.T) {
t.Parallel()
f := func(input, expected string) {
t.Helper()
result := ValidString(input)
require.Equal(t, expected, result)
require.True(t, utf8.ValidString(result))
}
f("hello world", "hello world")
// \xff and \xfe are invalid UTF-8 bytes and must be stripped
f("hello\xff\xfeworld", "helloworld")
f("", "")
}
func TestStreamFromWebsocketToWriter_ForwardsMessages(t *testing.T) {
t.Parallel()
serverConn, clientConn := newTestWSPair(t)
var buf bytes.Buffer
errorChan := make(chan error, 1)
go StreamFromWebsocketToWriter(serverConn, &buf, errorChan)
err := clientConn.WriteMessage(websocket.TextMessage, []byte("hello"))
require.NoError(t, err)
err = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
require.NoError(t, err)
// The goroutine sends to errorChan after processing all prior messages,
// so buf is fully written by the time we receive here
err = <-errorChan
require.Error(t, err)
require.Equal(t, "hello", buf.String())
}
func TestStreamFromWebsocketToWriter_HandlerInterceptsMessage(t *testing.T) {
t.Parallel()
serverConn, clientConn := newTestWSPair(t)
var buf bytes.Buffer
errorChan := make(chan error, 1)
intercepted := false
handler := MessageHandler(func(_ int, data []byte) bool {
if string(data) == "intercept" {
intercepted = true
return true
}
return false
})
go StreamFromWebsocketToWriter(serverConn, &buf, errorChan, handler)
err := clientConn.WriteMessage(websocket.TextMessage, []byte("intercept"))
require.NoError(t, err)
err = clientConn.WriteMessage(websocket.TextMessage, []byte("forward"))
require.NoError(t, err)
err = clientConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
require.NoError(t, err)
err = <-errorChan
require.Error(t, err)
require.True(t, intercepted)
require.Equal(t, "forward", buf.String())
}
func TestStreamFromWebsocketToWriter_WriteError(t *testing.T) {
t.Parallel()
serverConn, clientConn := newTestWSPair(t)
expectedErr := errors.New("write error")
errorChan := make(chan error, 1)
go StreamFromWebsocketToWriter(serverConn, &failWriter{err: expectedErr}, errorChan)
err := clientConn.WriteMessage(websocket.TextMessage, []byte("trigger"))
require.NoError(t, err)
err = <-errorChan
require.ErrorIs(t, err, expectedErr)
}
func TestStreamFromReaderToWebsocket_ForwardsData(t *testing.T) {
t.Parallel()
serverConn, clientConn := newTestWSPair(t)
reader := strings.NewReader("hello world")
errorChan := make(chan error, 1)
go StreamFromReaderToWebsocket(serverConn, reader, errorChan)
msgType, msg, err := clientConn.ReadMessage()
require.NoError(t, err)
require.Equal(t, websocket.TextMessage, msgType)
require.Equal(t, "hello world", string(msg))
err = <-errorChan
require.ErrorIs(t, err, io.EOF)
}
func TestStreamFromReaderToWebsocket_ReadError(t *testing.T) {
t.Parallel()
serverConn, _ := newTestWSPair(t)
expectedErr := errors.New("read error")
errorChan := make(chan error, 1)
go StreamFromReaderToWebsocket(serverConn, &failReader{err: expectedErr}, errorChan)
err := <-errorChan
require.ErrorIs(t, err, expectedErr)
}
func TestWriteReaderToWebSocket_ForwardsData(t *testing.T) {
t.Parallel()
serverConn, clientConn := newTestWSPair(t)
var mu sync.Mutex
reader := strings.NewReader("hello")
errorChan := make(chan error, 2)
go WriteReaderToWebSocket(serverConn, &mu, reader, errorChan)
msgType, msg, err := clientConn.ReadMessage()
require.NoError(t, err)
require.Equal(t, websocket.TextMessage, msgType)
require.Equal(t, "hello", string(msg))
// The inner goroutine sends EOF after the reader is exhausted
err = <-errorChan
require.ErrorIs(t, err, io.EOF)
}
func TestWriteReaderToWebSocket_ReaderError(t *testing.T) {
t.Parallel()
serverConn, _ := newTestWSPair(t)
var mu sync.Mutex
expectedErr := errors.New("read error")
errorChan := make(chan error, 2)
go WriteReaderToWebSocket(serverConn, &mu, &failReader{err: expectedErr}, errorChan)
err := <-errorChan
require.ErrorIs(t, err, expectedErr)
}
// newTestWSConn creates a websocket server connection without registering a close cleanup
// for it, allowing the caller to close it manually without triggering a double-close
func newTestWSConn(t *testing.T) *websocket.Conn {
t.Helper()
connCh := make(chan *websocket.Conn, 1)
doneCh := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if !assert.NoError(t, err) {
return
}
connCh <- conn
<-doneCh
}))
t.Cleanup(func() {
close(doneCh)
srv.Close()
})
u := "ws" + strings.TrimPrefix(srv.URL, "http")
client, resp, err := websocket.DefaultDialer.Dial(u, nil)
require.NoError(t, err)
t.Cleanup(func() { logs.CloseAndLogErr(resp.Body) })
t.Cleanup(func() { logs.CloseAndLogErr(client) })
return <-connCh
}
type mockSizeQueue struct {
cols uint16
rows uint16
}
func (m *mockSizeQueue) Push(cols, rows uint16) {
m.cols = cols
m.rows = rows
}
func TestResizeHandler(t *testing.T) {
t.Parallel()
f := func(msgType int, data []byte, expectedHandled bool, expectedCols, expectedRows uint16) {
t.Helper()
q := &mockSizeQueue{}
handled := ResizeHandler(q)(msgType, data)
require.Equal(t, expectedHandled, handled)
require.Equal(t, expectedCols, q.cols)
require.Equal(t, expectedRows, q.rows)
}
f(websocket.BinaryMessage, []byte(`{"type":"resize","data":{"width":80,"height":24}}`), false, 0, 0)
f(websocket.TextMessage, []byte(`not json`), false, 0, 0)
f(websocket.TextMessage, []byte(`{"type":"other","data":{"width":80,"height":24}}`), false, 0, 0)
f(websocket.TextMessage, []byte(`{"type":"resize","data":{"width":80,"height":24}}`), true, 80, 24)
}
func TestWsPing_Success(t *testing.T) {
t.Parallel()
serverConn, _ := newTestWSPair(t)
var mu sync.Mutex
err := wsPing(serverConn, &mu)
require.NoError(t, err)
}
func TestWsPing_ClosedConnection(t *testing.T) {
t.Parallel()
serverConn := newTestWSConn(t)
err := serverConn.Close()
require.NoError(t, err)
var mu sync.Mutex
err = wsPing(serverConn, &mu)
require.Error(t, err)
}
func TestWsWrite_ClosedConnection(t *testing.T) {
t.Parallel()
serverConn := newTestWSConn(t)
err := serverConn.Close()
require.NoError(t, err)
var mu sync.Mutex
err = wsWrite(serverConn, &mu, "hello")
require.Error(t, err)
}
func TestWriteReaderToWebSocket_ClosedConnection(t *testing.T) {
t.Parallel()
serverConn := newTestWSConn(t)
err := serverConn.Close()
require.NoError(t, err)
var mu sync.Mutex
errorChan := make(chan error, 2)
go WriteReaderToWebSocket(serverConn, &mu, strings.NewReader("hello"), errorChan)
err = <-errorChan
require.Error(t, err)
}
func TestStreamFromReaderToWebsocket_ClosedConnection(t *testing.T) {
t.Parallel()
serverConn := newTestWSConn(t)
err := serverConn.Close()
require.NoError(t, err)
errorChan := make(chan error, 1)
go StreamFromReaderToWebsocket(serverConn, strings.NewReader("hello"), errorChan)
err = <-errorChan
require.Error(t, err)
}
func TestHijackRequest_UnexpectedStatus(t *testing.T) {
t.Parallel()
serverConn, _ := newTestWSPair(t)
backendServer, backendClient := net.Pipe()
defer logs.CloseAndLogErr(backendServer)
defer logs.CloseAndLogErr(backendClient)
go func() {
buf := make([]byte, 4096)
if _, err := backendServer.Read(buf); err != nil {
return
}
resp := "HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n"
if _, err := backendServer.Write([]byte(resp)); err != nil {
return
}
}()
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://backend/", nil)
require.NoError(t, err)
err = HijackRequest(serverConn, backendClient, req)
require.ErrorContains(t, err, "403")
}
func TestHijackRequest_Success(t *testing.T) {
t.Parallel()
serverConn, _ := newTestWSPair(t)
backendServer, backendClient := net.Pipe()
defer logs.CloseAndLogErr(backendClient)
go func() {
defer logs.CloseAndLogErr(backendServer)
buf := make([]byte, 4096)
if _, err := backendServer.Read(buf); err != nil {
return
}
resp := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"
if _, err := backendServer.Write([]byte(resp)); err != nil {
return
}
}()
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://backend/", nil)
require.NoError(t, err)
// The backend closes after sending 101, which delivers EOF to the reader goroutine.
// HijackRequest treats EOF (not a CloseError) as a clean session end and returns nil.
err = HijackRequest(serverConn, backendClient, req)
require.NoError(t, err)
}