mirror of
https://github.com/shizunge/endlessh-go.git
synced 2026-06-22 20:00:27 +00:00
325 lines
8.4 KiB
Go
325 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net"
|
|
"os/exec"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
const (
|
|
waitForListenTimeout = 10 * time.Second
|
|
waitForConnectTimeout = 3 * time.Second
|
|
pollInterval = 50 * time.Millisecond
|
|
)
|
|
|
|
func waitForLogMatch(stderr *bytes.Buffer, pattern string, timeout time.Duration) bool {
|
|
re := regexp.MustCompile(pattern)
|
|
deadline := time.Now().Add(timeout)
|
|
for {
|
|
if time.Now().After(deadline) {
|
|
return false
|
|
}
|
|
if re.MatchString(stderr.String()) {
|
|
return true
|
|
}
|
|
time.Sleep(pollInterval)
|
|
}
|
|
}
|
|
|
|
func TestEndlesshIntegration_MultiplePorts(t *testing.T) {
|
|
const nPorts = 3
|
|
ports := make([]int, nPorts)
|
|
for i := 0; i < nPorts; i++ {
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("failed to get free port: %v", err)
|
|
}
|
|
ports[i] = ln.Addr().(*net.TCPAddr).Port
|
|
ln.Close()
|
|
}
|
|
args := []string{"run", "main.go",
|
|
"-interval_ms=100",
|
|
"-max_clients=10",
|
|
"-logtostderr",
|
|
"-v=1",
|
|
}
|
|
for _, p := range ports {
|
|
args = append(args, fmt.Sprintf("-port=%d", p))
|
|
}
|
|
|
|
cmd := exec.Command("go", args...)
|
|
var stderr bytes.Buffer
|
|
cmd.Stderr = &stderr
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
}
|
|
defer func() {
|
|
if err := cmd.Process.Kill(); err != nil {
|
|
t.Logf("Failed to kill process: %v", err)
|
|
}
|
|
}()
|
|
|
|
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
|
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
|
}
|
|
for _, port := range ports {
|
|
addr := fmt.Sprintf("localhost:%d", port)
|
|
conn, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to server on port %d: %v", port, err)
|
|
}
|
|
if !waitForLogMatch(&stderr, `ACCEPT host=(127\.0\.0\.1|::1)`, waitForConnectTimeout) {
|
|
t.Errorf("Never saw any ACCEPT log, got logs: %s", stderr.String())
|
|
}
|
|
conn.Close()
|
|
|
|
if !waitForLogMatch(&stderr, `CLOSE host=(127\.0\.0\.1|::1)`, waitForConnectTimeout) {
|
|
t.Errorf("Never saw any CLOSE log, got logs: %s", stderr.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestEndlesshIntegration_TarpitBehavior(t *testing.T) {
|
|
var stderr bytes.Buffer
|
|
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=5000", "-max_clients=10", "-logtostderr", "-v=1")
|
|
cmd.Stderr = &stderr
|
|
if err := cmd.Start(); err != nil {
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
}
|
|
defer func() {
|
|
if err := cmd.Process.Kill(); err != nil {
|
|
t.Logf("Failed to kill process: %v", err)
|
|
}
|
|
}()
|
|
|
|
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
|
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
|
}
|
|
|
|
stderrOutput := stderr.String()
|
|
re := regexp.MustCompile(`Listening on .*:(\d+)`)
|
|
m := re.FindStringSubmatch(stderrOutput)
|
|
if len(m) != 2 {
|
|
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
|
|
}
|
|
port := m[1]
|
|
addr := "localhost:" + port
|
|
|
|
conn, err := net.Dial("tcp", addr)
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to server on port %s: %v", port, err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Simulate SSH client banner
|
|
// Connect & send client banner
|
|
_, err = conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
|
|
if err != nil {
|
|
t.Fatalf("Write failed: %v", err)
|
|
}
|
|
|
|
// Expect FIRST LINE immediately (no delay)
|
|
buf := make([]byte, 1024)
|
|
n, err := conn.Read(buf)
|
|
if err != nil || n == 0 {
|
|
t.Fatalf("Expected first tarpit line immediately, got: %v (%d bytes)", err, n)
|
|
}
|
|
t.Logf("Got first tarpit line (%d bytes): %q", n, buf[:n])
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
|
n2, err := conn.Read(buf)
|
|
if err == nil && n2 > 0 {
|
|
t.Errorf("Got %d bytes too soon (within 500ms), tarpit failed", n2)
|
|
}
|
|
}
|
|
|
|
func TestEndlesshIntegration_Concurrency(t *testing.T) {
|
|
maxClients := 5
|
|
var stderr bytes.Buffer
|
|
cmd := exec.Command("go", "run", "main.go", "-port=0", "-interval_ms=1000", fmt.Sprintf("-max_clients=%d", maxClients), "-logtostderr", "-v=1")
|
|
cmd.Stderr = &stderr
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
}
|
|
defer func() {
|
|
if err := cmd.Process.Kill(); err != nil {
|
|
t.Logf("Failed to kill process: %v", err)
|
|
}
|
|
}()
|
|
|
|
if !waitForLogMatch(&stderr, "Listening on", waitForListenTimeout) {
|
|
t.Fatalf("Timeout waiting for server to start, got logs: %s", stderr.String())
|
|
}
|
|
stderrOutput := stderr.String()
|
|
re := regexp.MustCompile(`Listening on .*:(\d+)`)
|
|
m := re.FindStringSubmatch(stderrOutput)
|
|
if len(m) != 2 {
|
|
t.Fatalf("Could not parse port from logs: %s", stderrOutput)
|
|
}
|
|
port := m[1]
|
|
addr := fmt.Sprintf("localhost:%s", port)
|
|
|
|
// Test multiple connections
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
activeClients := 0
|
|
maxActiveClients := 0
|
|
//sixthClientFailed := false
|
|
successfulReads := 0
|
|
|
|
for i := 0; i < maxClients+1; i++ {
|
|
wg.Add(1)
|
|
go func(clientID int) {
|
|
defer wg.Done()
|
|
|
|
conn, dialErr := net.Dial("tcp", addr)
|
|
if dialErr != nil {
|
|
if clientID == maxClients {
|
|
//sixthClientFailed = true
|
|
t.Logf("Client %d dial failed (expected): %v", clientID, dialErr)
|
|
} else {
|
|
t.Errorf("Client %d dial failed (unexpected): %v", clientID, dialErr)
|
|
}
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, writeErr := conn.Write([]byte("SSH-2.0-OpenSSH_8.2p1\r\n"))
|
|
if writeErr != nil {
|
|
t.Logf("Client %d write failed: %v", clientID, writeErr)
|
|
return
|
|
}
|
|
|
|
buf := make([]byte, 1024)
|
|
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
|
|
n1, readErr1 := conn.Read(buf)
|
|
if readErr1 != nil || n1 == 0 {
|
|
if clientID != maxClients { // only log unexpected failures
|
|
t.Logf("Client %d failed first tarpit line (unexpected): %v (%d bytes)", clientID, readErr1, n1)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Client is active only if it successfully received data
|
|
mu.Lock()
|
|
activeClients++
|
|
if activeClients > maxActiveClients {
|
|
maxActiveClients = activeClients
|
|
}
|
|
successfulReads++
|
|
mu.Unlock()
|
|
|
|
t.Logf("Client %d got line 1 (%d bytes): %q", clientID, n1, buf[:n1])
|
|
|
|
// Keep the connection open for a while
|
|
time.Sleep(5 * time.Second)
|
|
|
|
mu.Lock()
|
|
activeClients--
|
|
mu.Unlock()
|
|
}(i)
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
wg.Wait()
|
|
|
|
mu.Lock()
|
|
deferredSuccessfulReads := successfulReads
|
|
mu.Unlock()
|
|
|
|
if deferredSuccessfulReads < 1 {
|
|
t.Errorf("Expected at least one client to receive a tarpit line, got %d", deferredSuccessfulReads)
|
|
}
|
|
|
|
// Check if maxActiveClients exceeded maxClients
|
|
if maxActiveClients > maxClients {
|
|
t.Errorf("Expected max %d concurrent clients, got %d", maxClients, maxActiveClients)
|
|
}
|
|
|
|
// Check if the 6th client failed
|
|
/*mu.Lock()
|
|
if !sixthClientFailed {
|
|
t.Errorf("Expected 6th client to fail, but it succeeded")
|
|
}
|
|
mu.Unlock()
|
|
*/
|
|
// The 6th client does not necessarily fail at Dial because max_clients is
|
|
// enforced at the protocol/Read stage instead of at Accept.
|
|
//
|
|
// TODO: Enforce max_clients at Accept so the N+1 client fails fast at Dial.
|
|
}
|
|
|
|
func TestEndlesshIntegration_PrometheusMetrics(t *testing.T) {
|
|
var stderr bytes.Buffer
|
|
|
|
cmd := exec.Command(
|
|
"go", "run", "main.go",
|
|
"-port=0",
|
|
"-enable_prometheus",
|
|
"-prometheus_port=0",
|
|
"-interval_ms=100",
|
|
"-logtostderr", "-v=1",
|
|
)
|
|
cmd.Stderr = &stderr
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
t.Fatalf("Failed to start server: %v", err)
|
|
}
|
|
defer cmd.Process.Kill()
|
|
|
|
if !waitForLogMatch(&stderr, "Starting Prometheus", waitForListenTimeout) {
|
|
t.Fatalf("Prometheus did not start: %s", stderr.String())
|
|
}
|
|
|
|
reProm := regexp.MustCompile(`Prometheus on IP port .*:(\d+)`)
|
|
reMain := regexp.MustCompile(`Listening on .*:(\d+)`)
|
|
|
|
promMatch := reProm.FindStringSubmatch(stderr.String())
|
|
mainMatch := reMain.FindStringSubmatch(stderr.String())
|
|
|
|
if len(promMatch) < 2 || len(mainMatch) < 2 {
|
|
t.Fatalf("Could not parse ports: %s", stderr.String())
|
|
}
|
|
|
|
promPort := promMatch[1]
|
|
mainPort := mainMatch[1]
|
|
|
|
conn, err := net.Dial("tcp", "localhost:"+mainPort)
|
|
if err != nil {
|
|
t.Fatalf("Dial failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
conn.Write([]byte("SSH-2.0-test\r\n"))
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// Fetch metrics
|
|
resp, err := net.Dial("tcp", "localhost:"+promPort)
|
|
if err != nil {
|
|
t.Fatalf("Failed to connect to metrics endpoint: %v", err)
|
|
}
|
|
|
|
fmt.Fprintf(resp, "GET /metrics HTTP/1.1\r\nHost: localhost\r\n\r\n")
|
|
|
|
buf := make([]byte, 8192)
|
|
n, _ := resp.Read(buf)
|
|
body := string(buf[:n])
|
|
|
|
if !strings.Contains(body, "endlessh_client_open_count_total") {
|
|
t.Errorf("Missing expected metric in output:\n%s", body)
|
|
}
|
|
|
|
if !strings.Contains(body, "endlessh_sent_bytes_total") {
|
|
t.Errorf("Expected bytes metric not found:\n%s", body)
|
|
}
|
|
}
|