Merge pull request #188 from BenjaminGoehry/test/integration_test

test: add integration tests for multi-port, tarpit behavior, and max_clients limit
This commit is contained in:
Shizun Ge
2026-02-23 21:28:21 -08:00
committed by GitHub
4 changed files with 306 additions and 3 deletions
+18
View File
@@ -15,6 +15,24 @@ env:
PLATFORMS: "linux/amd64,linux/arm64,linux/arm/v7"
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.24.2'
cache: true
- name: Install dependencies
run: go mod download
- name: Run Tests
run: go test -v ./...
build_container_image:
name: Build Docker image
runs-on: ubuntu-latest
+17
View File
@@ -16,6 +16,23 @@ env:
PLATFORMS: "linux/amd64,linux/arm64,linux/arm/v7"
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.24.2'
cache: true
- name: Install dependencies
run: go mod download
- name: Run Tests
run: go test -v ./...
build_and_push:
name: Build and push Docker image
runs-on: ubuntu-latest
+258
View File
@@ -0,0 +1,258 @@
package main
import (
"bytes"
"fmt"
"net"
"os/exec"
"regexp"
"sync"
"testing"
"time"
)
const (
waitForListenTimeout = 10 * time.Second
waitForConnectTimeout = 3 * time.Second
pollInterval = 50 * time.Millisecond
)
func waitForLogMatch(stderr *bytes.Buffer, pattern string, timeout time.Duration) bool {
re := regexp.MustCompile(pattern)
deadline := time.Now().Add(timeout)
for {
if time.Now().After(deadline) {
return false
}
if re.MatchString(stderr.String()) {
return true
}
time.Sleep(pollInterval)
}
}
func TestEndlesshIntegration_MultiplePorts(t *testing.T) {
const nPorts = 3
ports := make([]int, nPorts)
for i := 0; i < nPorts; i++ {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to get free port: %v", err)
}
ports[i] = ln.Addr().(*net.TCPAddr).Port
ln.Close()
}
args := []string{"run", "main.go",
"-interval_ms=100",
"-max_clients=10",
"-logtostderr",
"-v=1",
}
for _, p := range ports {
args = append(args, fmt.Sprintf("-port=%d", p))
}
cmd := exec.Command("go", args...)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
for _, port := range ports {
addr := fmt.Sprintf("localhost:%d", port)
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Failed to connect to server on port %d: %v", port, err)
}
if !waitForLogMatch(&stderr, "ACCEPT host=127.0.0.1", waitForConnectTimeout) {
t.Errorf("Never saw any ACCEPT log, got logs: %s", stderr.String())
}
conn.Close()
if !waitForLogMatch(&stderr, "CLOSE host=127.0.0.1", waitForConnectTimeout) {
t.Errorf("Never saw any CLOSE log, got logs: %s", stderr.String())
}
}
}
func TestEndlesshIntegration_TarpitBehavior(t *testing.T) {
var stderr bytes.Buffer
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=5000", "-max_clients=10", "-logtostderr", "-v=1")
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
stderrOutput := stderr.String()
re := regexp.MustCompile(`Listening on .*:(\d+)`)
m := re.FindStringSubmatch(stderrOutput)
if len(m) != 2 {
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
}
port := m[1]
addr := "localhost:" + port
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Failed to connect to server on port %s: %v", port, err)
}
defer conn.Close()
// Simulate SSH client banner
// Connect & send client banner
_, err = conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
if err != nil {
t.Fatalf("Write failed: %v", err)
}
// Expect FIRST LINE immediately (no delay)
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil || n == 0 {
t.Fatalf("Expected first tarpit line immediately, got: %v (%d bytes)", err, n)
}
t.Logf("Got first tarpit line (%d bytes): %q", n, buf[:n])
time.Sleep(100 * time.Millisecond)
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
n2, err := conn.Read(buf)
if err == nil && n2 > 0 {
t.Errorf("Got %d bytes too soon (within 500ms), tarpit failed", n2)
}
}
func TestEndlesshIntegration_Concurrency(t *testing.T) {
maxClients := 5
var stderr bytes.Buffer
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=1000", fmt.Sprintf("-max_clients=%d", maxClients), "-logtostderr", "-v=1")
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
stderrOutput := stderr.String()
re := regexp.MustCompile(`Listening on .*:(\d+)`)
m := re.FindStringSubmatch(stderrOutput)
if len(m) != 2 {
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
}
port := m[1]
addr := fmt.Sprintf("localhost:%s", port)
// Test multiple connections
var wg sync.WaitGroup
var mu sync.Mutex
activeClients := 0
maxActiveClients := 0
//sixthClientFailed := false
successfulReads := 0
for i := 0; i < maxClients+1; i++ {
wg.Add(1)
go func(clientID int) {
defer wg.Done()
conn, dialErr := net.Dial("tcp", addr)
if dialErr != nil {
if clientID == maxClients {
//sixthClientFailed = true
t.Logf("Client %d dial failed (expected): %v", clientID, dialErr)
} else {
t.Errorf("Client %d dial failed (unexpected): %v", clientID, dialErr)
}
return
}
defer conn.Close()
_, writeErr := conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
if writeErr != nil {
t.Logf("Client %d write failed: %v", clientID, writeErr)
return
}
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
n1, readErr1 := conn.Read(buf)
if readErr1 != nil || n1 == 0 {
if clientID != maxClients { // only log unexpected failures
t.Logf("Client %d failed first tarpit line (unexpected): %v (%d bytes)", clientID, readErr1, n1)
}
return
}
// Client is active only if it successfully received data
mu.Lock()
activeClients++
if activeClients > maxActiveClients {
maxActiveClients = activeClients
}
successfulReads++
mu.Unlock()
t.Logf("Client %d got line 1 (%d bytes): %q", clientID, n1, buf[:n1])
// Keep the connection open for a while
time.Sleep(5 * time.Second)
mu.Lock()
activeClients--
mu.Unlock()
}(i)
time.Sleep(200 * time.Millisecond)
}
wg.Wait()
mu.Lock()
deferredSuccessfulReads := successfulReads
mu.Unlock()
if deferredSuccessfulReads < 1 {
t.Errorf("Expected at least one client to receive a tarpit line, got %d", deferredSuccessfulReads)
}
// Check if maxActiveClients exceeded maxClients
if maxActiveClients > maxClients {
t.Errorf("Expected max %d concurrent clients, got %d", maxClients, maxActiveClients)
}
// Check if the 6th client failed
/*mu.Lock()
if !sixthClientFailed {
t.Errorf("Expected 6th client to fail, but it succeeded")
}
mu.Unlock()
*/
// The 6th client does not necessarily fail at Dial because max_clients is
// enforced at the protocol/Read stage instead of at Accept.
//
// TODO: Enforce max_clients at Accept so the N+1 client fails fast at Dial.
}
+13 -3
View File
@@ -24,6 +24,7 @@ import (
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
@@ -70,7 +71,14 @@ func startSending(maxClients int64, bannerMaxLength int64, records chan<- metric
func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) {
go func() {
l, err := net.Listen(connType, connHost+":"+connPort)
connPortInt, err := strconv.Atoi(connPort)
if err != nil {
glog.Errorf("Invalid port: %v", err)
os.Exit(1)
}
addr := connHost + fmt.Sprintf(":%d", connPortInt)
l, err := net.Listen(connType, addr)
if err != nil {
glog.Errorf("Error listening: %v", err)
os.Exit(1)
@@ -84,7 +92,9 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter
// Close the listener when the application closes.
defer l.Close()
glog.Infof("Listening on %v:%v", connHost, connPort)
realAddr := l.Addr().(*net.TCPAddr)
localPortStr := strconv.Itoa(realAddr.Port)
glog.Infof("Listening on %v:%v", connHost, realAddr.Port)
for {
// Listen for an incoming connection.
conn, err := l.Accept()
@@ -97,7 +107,7 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter
records <- metrics.RecordEntry{
RecordType: metrics.RecordEntryTypeStart,
IpAddr: remoteIpAddr,
LocalPort: connPort,
LocalPort: localPortStr,
}
clients <- c
}