mirror of
https://github.com/shizunge/endlessh-go.git
synced 2026-06-23 04:10:08 +00:00
Merge pull request #188 from BenjaminGoehry/test/integration_test
test: add integration tests for multi-port, tarpit behavior, and max_clients limit
This commit is contained in:
@@ -15,6 +15,24 @@ env:
|
||||
PLATFORMS: "linux/amd64,linux/arm64,linux/arm/v7"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24.2'
|
||||
cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run Tests
|
||||
run: go test -v ./...
|
||||
|
||||
build_container_image:
|
||||
name: Build Docker image
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -16,6 +16,23 @@ env:
|
||||
PLATFORMS: "linux/amd64,linux/arm64,linux/arm/v7"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.24.2'
|
||||
cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run Tests
|
||||
run: go test -v ./...
|
||||
build_and_push:
|
||||
name: Build and push Docker image
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
waitForListenTimeout = 10 * time.Second
|
||||
waitForConnectTimeout = 3 * time.Second
|
||||
pollInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
func waitForLogMatch(stderr *bytes.Buffer, pattern string, timeout time.Duration) bool {
|
||||
re := regexp.MustCompile(pattern)
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return false
|
||||
}
|
||||
if re.MatchString(stderr.String()) {
|
||||
return true
|
||||
}
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndlesshIntegration_MultiplePorts(t *testing.T) {
|
||||
const nPorts = 3
|
||||
ports := make([]int, nPorts)
|
||||
for i := 0; i < nPorts; i++ {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get free port: %v", err)
|
||||
}
|
||||
ports[i] = ln.Addr().(*net.TCPAddr).Port
|
||||
ln.Close()
|
||||
}
|
||||
args := []string{"run", "main.go",
|
||||
"-interval_ms=100",
|
||||
"-max_clients=10",
|
||||
"-logtostderr",
|
||||
"-v=1",
|
||||
}
|
||||
for _, p := range ports {
|
||||
args = append(args, fmt.Sprintf("-port=%d", p))
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", args...)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("Failed to kill process: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
||||
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
||||
}
|
||||
for _, port := range ports {
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to server on port %d: %v", port, err)
|
||||
}
|
||||
if !waitForLogMatch(&stderr, "ACCEPT host=127.0.0.1", waitForConnectTimeout) {
|
||||
t.Errorf("Never saw any ACCEPT log, got logs: %s", stderr.String())
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
if !waitForLogMatch(&stderr, "CLOSE host=127.0.0.1", waitForConnectTimeout) {
|
||||
t.Errorf("Never saw any CLOSE log, got logs: %s", stderr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndlesshIntegration_TarpitBehavior(t *testing.T) {
|
||||
var stderr bytes.Buffer
|
||||
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=5000", "-max_clients=10", "-logtostderr", "-v=1")
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("Failed to kill process: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
||||
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
||||
}
|
||||
|
||||
stderrOutput := stderr.String()
|
||||
re := regexp.MustCompile(`Listening on .*:(\d+)`)
|
||||
m := re.FindStringSubmatch(stderrOutput)
|
||||
if len(m) != 2 {
|
||||
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
|
||||
}
|
||||
port := m[1]
|
||||
addr := "localhost:" + port
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to server on port %s: %v", port, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Simulate SSH client banner
|
||||
// Connect & send client banner
|
||||
_, err = conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
// Expect FIRST LINE immediately (no delay)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil || n == 0 {
|
||||
t.Fatalf("Expected first tarpit line immediately, got: %v (%d bytes)", err, n)
|
||||
}
|
||||
t.Logf("Got first tarpit line (%d bytes): %q", n, buf[:n])
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n2, err := conn.Read(buf)
|
||||
if err == nil && n2 > 0 {
|
||||
t.Errorf("Got %d bytes too soon (within 500ms), tarpit failed", n2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndlesshIntegration_Concurrency(t *testing.T) {
|
||||
maxClients := 5
|
||||
var stderr bytes.Buffer
|
||||
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=1000", fmt.Sprintf("-max_clients=%d", maxClients), "-logtostderr", "-v=1")
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("Failed to kill process: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
||||
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
||||
}
|
||||
|
||||
stderrOutput := stderr.String()
|
||||
re := regexp.MustCompile(`Listening on .*:(\d+)`)
|
||||
m := re.FindStringSubmatch(stderrOutput)
|
||||
if len(m) != 2 {
|
||||
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
|
||||
}
|
||||
port := m[1]
|
||||
addr := fmt.Sprintf("localhost:%s", port)
|
||||
|
||||
// Test multiple connections
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
activeClients := 0
|
||||
maxActiveClients := 0
|
||||
//sixthClientFailed := false
|
||||
successfulReads := 0
|
||||
|
||||
for i := 0; i < maxClients+1; i++ {
|
||||
wg.Add(1)
|
||||
go func(clientID int) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, dialErr := net.Dial("tcp", addr)
|
||||
if dialErr != nil {
|
||||
if clientID == maxClients {
|
||||
//sixthClientFailed = true
|
||||
t.Logf("Client %d dial failed (expected): %v", clientID, dialErr)
|
||||
} else {
|
||||
t.Errorf("Client %d dial failed (unexpected): %v", clientID, dialErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, writeErr := conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
|
||||
if writeErr != nil {
|
||||
t.Logf("Client %d write failed: %v", clientID, writeErr)
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n1, readErr1 := conn.Read(buf)
|
||||
if readErr1 != nil || n1 == 0 {
|
||||
if clientID != maxClients { // only log unexpected failures
|
||||
t.Logf("Client %d failed first tarpit line (unexpected): %v (%d bytes)", clientID, readErr1, n1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Client is active only if it successfully received data
|
||||
mu.Lock()
|
||||
activeClients++
|
||||
if activeClients > maxActiveClients {
|
||||
maxActiveClients = activeClients
|
||||
}
|
||||
successfulReads++
|
||||
mu.Unlock()
|
||||
|
||||
t.Logf("Client %d got line 1 (%d bytes): %q", clientID, n1, buf[:n1])
|
||||
|
||||
// Keep the connection open for a while
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
mu.Lock()
|
||||
activeClients--
|
||||
mu.Unlock()
|
||||
}(i)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
deferredSuccessfulReads := successfulReads
|
||||
mu.Unlock()
|
||||
|
||||
if deferredSuccessfulReads < 1 {
|
||||
t.Errorf("Expected at least one client to receive a tarpit line, got %d", deferredSuccessfulReads)
|
||||
}
|
||||
|
||||
// Check if maxActiveClients exceeded maxClients
|
||||
if maxActiveClients > maxClients {
|
||||
t.Errorf("Expected max %d concurrent clients, got %d", maxClients, maxActiveClients)
|
||||
}
|
||||
|
||||
// Check if the 6th client failed
|
||||
/*mu.Lock()
|
||||
if !sixthClientFailed {
|
||||
t.Errorf("Expected 6th client to fail, but it succeeded")
|
||||
}
|
||||
mu.Unlock()
|
||||
*/
|
||||
// The 6th client does not necessarily fail at Dial because max_clients is
|
||||
// enforced at the protocol/Read stage instead of at Accept.
|
||||
//
|
||||
// TODO: Enforce max_clients at Accept so the N+1 client fails fast at Dial.
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -70,7 +71,14 @@ func startSending(maxClients int64, bannerMaxLength int64, records chan<- metric
|
||||
|
||||
func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) {
|
||||
go func() {
|
||||
l, err := net.Listen(connType, connHost+":"+connPort)
|
||||
connPortInt, err := strconv.Atoi(connPort)
|
||||
if err != nil {
|
||||
glog.Errorf("Invalid port: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
addr := connHost + fmt.Sprintf(":%d", connPortInt)
|
||||
l, err := net.Listen(connType, addr)
|
||||
if err != nil {
|
||||
glog.Errorf("Error listening: %v", err)
|
||||
os.Exit(1)
|
||||
@@ -84,7 +92,9 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter
|
||||
// Close the listener when the application closes.
|
||||
defer l.Close()
|
||||
|
||||
glog.Infof("Listening on %v:%v", connHost, connPort)
|
||||
realAddr := l.Addr().(*net.TCPAddr)
|
||||
localPortStr := strconv.Itoa(realAddr.Port)
|
||||
glog.Infof("Listening on %v:%v", connHost, realAddr.Port)
|
||||
for {
|
||||
// Listen for an incoming connection.
|
||||
conn, err := l.Accept()
|
||||
@@ -97,7 +107,7 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter
|
||||
records <- metrics.RecordEntry{
|
||||
RecordType: metrics.RecordEntryTypeStart,
|
||||
IpAddr: remoteIpAddr,
|
||||
LocalPort: connPort,
|
||||
LocalPort: localPortStr,
|
||||
}
|
||||
clients <- c
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user