Files
endlessh-go/endlessh_integration_test.go

325 lines
8.4 KiB
Go

package main
import (
"bytes"
"fmt"
"net"
"os/exec"
"regexp"
"strings"
"sync"
"testing"
"time"
)
const (
waitForListenTimeout = 10 * time.Second
waitForConnectTimeout = 3 * time.Second
pollInterval = 50 * time.Millisecond
)
func waitForLogMatch(stderr *bytes.Buffer, pattern string, timeout time.Duration) bool {
re := regexp.MustCompile(pattern)
deadline := time.Now().Add(timeout)
for {
if time.Now().After(deadline) {
return false
}
if re.MatchString(stderr.String()) {
return true
}
time.Sleep(pollInterval)
}
}
func TestEndlesshIntegration_MultiplePorts(t *testing.T) {
const nPorts = 3
ports := make([]int, nPorts)
for i := 0; i < nPorts; i++ {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to get free port: %v", err)
}
ports[i] = ln.Addr().(*net.TCPAddr).Port
ln.Close()
}
args := []string{"run", "main.go",
"-interval_ms=100",
"-max_clients=10",
"-logtostderr",
"-v=1",
}
for _, p := range ports {
args = append(args, fmt.Sprintf("-port=%d", p))
}
cmd := exec.Command("go", args...)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
for _, port := range ports {
addr := fmt.Sprintf("localhost:%d", port)
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Failed to connect to server on port %d: %v", port, err)
}
if !waitForLogMatch(&stderr, `ACCEPT host=(127\.0\.0\.1|::1)`, waitForConnectTimeout) {
t.Errorf("Never saw any ACCEPT log, got logs: %s", stderr.String())
}
conn.Close()
if !waitForLogMatch(&stderr, `CLOSE host=(127\.0\.0\.1|::1)`, waitForConnectTimeout) {
t.Errorf("Never saw any CLOSE log, got logs: %s", stderr.String())
}
}
}
func TestEndlesshIntegration_TarpitBehavior(t *testing.T) {
var stderr bytes.Buffer
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=5000", "-max_clients=10", "-logtostderr", "-v=1")
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
stderrOutput := stderr.String()
re := regexp.MustCompile(`Listening on .*:(\d+)`)
m := re.FindStringSubmatch(stderrOutput)
if len(m) != 2 {
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
}
port := m[1]
addr := "localhost:" + port
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Failed to connect to server on port %s: %v", port, err)
}
defer conn.Close()
// Simulate SSH client banner
// Connect & send client banner
_, err = conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
if err != nil {
t.Fatalf("Write failed: %v", err)
}
// Expect FIRST LINE immediately (no delay)
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil || n == 0 {
t.Fatalf("Expected first tarpit line immediately, got: %v (%d bytes)", err, n)
}
t.Logf("Got first tarpit line (%d bytes): %q", n, buf[:n])
time.Sleep(100 * time.Millisecond)
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
n2, err := conn.Read(buf)
if err == nil && n2 > 0 {
t.Errorf("Got %d bytes too soon (within 500ms), tarpit failed", n2)
}
}
func TestEndlesshIntegration_Concurrency(t *testing.T) {
maxClients := 5
var stderr bytes.Buffer
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=1000", fmt.Sprintf("-max_clients=%d", maxClients), "-logtostderr", "-v=1")
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer func() {
if err := cmd.Process.Kill(); err != nil {
t.Logf("Failed to kill process: %v", err)
}
}()
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
}
stderrOutput := stderr.String()
re := regexp.MustCompile(`Listening on .*:(\d+)`)
m := re.FindStringSubmatch(stderrOutput)
if len(m) != 2 {
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
}
port := m[1]
addr := fmt.Sprintf("localhost:%s", port)
// Test multiple connections
var wg sync.WaitGroup
var mu sync.Mutex
activeClients := 0
maxActiveClients := 0
//sixthClientFailed := false
successfulReads := 0
for i := 0; i < maxClients+1; i++ {
wg.Add(1)
go func(clientID int) {
defer wg.Done()
conn, dialErr := net.Dial("tcp", addr)
if dialErr != nil {
if clientID == maxClients {
//sixthClientFailed = true
t.Logf("Client %d dial failed (expected): %v", clientID, dialErr)
} else {
t.Errorf("Client %d dial failed (unexpected): %v", clientID, dialErr)
}
return
}
defer conn.Close()
_, writeErr := conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
if writeErr != nil {
t.Logf("Client %d write failed: %v", clientID, writeErr)
return
}
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
n1, readErr1 := conn.Read(buf)
if readErr1 != nil || n1 == 0 {
if clientID != maxClients { // only log unexpected failures
t.Logf("Client %d failed first tarpit line (unexpected): %v (%d bytes)", clientID, readErr1, n1)
}
return
}
// Client is active only if it successfully received data
mu.Lock()
activeClients++
if activeClients > maxActiveClients {
maxActiveClients = activeClients
}
successfulReads++
mu.Unlock()
t.Logf("Client %d got line 1 (%d bytes): %q", clientID, n1, buf[:n1])
// Keep the connection open for a while
time.Sleep(5 * time.Second)
mu.Lock()
activeClients--
mu.Unlock()
}(i)
time.Sleep(200 * time.Millisecond)
}
wg.Wait()
mu.Lock()
deferredSuccessfulReads := successfulReads
mu.Unlock()
if deferredSuccessfulReads < 1 {
t.Errorf("Expected at least one client to receive a tarpit line, got %d", deferredSuccessfulReads)
}
// Check if maxActiveClients exceeded maxClients
if maxActiveClients > maxClients {
t.Errorf("Expected max %d concurrent clients, got %d", maxClients, maxActiveClients)
}
// Check if the 6th client failed
/*mu.Lock()
if !sixthClientFailed {
t.Errorf("Expected 6th client to fail, but it succeeded")
}
mu.Unlock()
*/
// The 6th client does not necessarily fail at Dial because max_clients is
// enforced at the protocol/Read stage instead of at Accept.
//
// TODO: Enforce max_clients at Accept so the N+1 client fails fast at Dial.
}
func TestEndlesshIntegration_PrometheusMetrics(t *testing.T) {
var stderr bytes.Buffer
cmd := exec.Command(
"go", "run", "main.go",
"-port=0",
"-enable_prometheus",
"-prometheus_port=0",
"-interval_ms=100",
"-logtostderr", "-v=1",
)
cmd.Stderr = &stderr
if err := cmd.Start(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer cmd.Process.Kill()
if !waitForLogMatch(&stderr, "Starting Prometheus", waitForListenTimeout) {
t.Fatalf("Prometheus did not start: %s", stderr.String())
}
reProm := regexp.MustCompile(`Prometheus on IP port .*:(\d+)`)
reMain := regexp.MustCompile(`Listening on .*:(\d+)`)
promMatch := reProm.FindStringSubmatch(stderr.String())
mainMatch := reMain.FindStringSubmatch(stderr.String())
if len(promMatch) < 2 || len(mainMatch) < 2 {
t.Fatalf("Could not parse ports: %s", stderr.String())
}
promPort := promMatch[1]
mainPort := mainMatch[1]
conn, err := net.Dial("tcp", "localhost:"+mainPort)
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
defer conn.Close()
conn.Write([]byte("SSH-2.0-test\r\n"))
time.Sleep(500 * time.Millisecond)
// Fetch metrics
resp, err := net.Dial("tcp", "localhost:"+promPort)
if err != nil {
t.Fatalf("Failed to connect to metrics endpoint: %v", err)
}
fmt.Fprintf(resp, "GET /metrics HTTP/1.1\r\nHost: localhost\r\n\r\n")
buf := make([]byte, 8192)
n, _ := resp.Read(buf)
body := string(buf[:n])
if !strings.Contains(body, "endlessh_client_open_count_total") {
t.Errorf("Missing expected metric in output:\n%s", body)
}
if !strings.Contains(body, "endlessh_sent_bytes_total") {
t.Errorf("Expected bytes metric not found:\n%s", body)
}
}