diff --git a/endlessh_integration_test.go b/endlessh_integration_test.go index 7f86e5f..c177c96 100644 --- a/endlessh_integration_test.go +++ b/endlessh_integration_test.go @@ -6,6 +6,7 @@ import ( "net" "os/exec" "regexp" + "strings" "sync" "testing" "time" @@ -255,3 +256,69 @@ func TestEndlesshIntegration_Concurrency(t *testing.T) { // // TODO: Enforce max_clients at Accept so the N+1 client fails fast at Dial. } + +func TestEndlesshIntegration_PrometheusMetrics(t *testing.T) { + var stderr bytes.Buffer + + cmd := exec.Command( + "go", "run", "main.go", + "-port=0", + "-enable_prometheus", + "-prometheus_port=0", + "-interval_ms=100", + "-logtostderr", "-v=1", + ) + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + t.Fatalf("Failed to start server: %v", err) + } + defer cmd.Process.Kill() + + if !waitForLogMatch(&stderr, "Starting Prometheus", waitForListenTimeout) { + t.Fatalf("Prometheus did not start: %s", stderr.String()) + } + + reProm := regexp.MustCompile(`Prometheus on IP port .*:(\d+)`) + reMain := regexp.MustCompile(`Listening on .*:(\d+)`) + + promMatch := reProm.FindStringSubmatch(stderr.String()) + mainMatch := reMain.FindStringSubmatch(stderr.String()) + + if len(promMatch) < 2 || len(mainMatch) < 2 { + t.Fatalf("Could not parse ports: %s", stderr.String()) + } + + promPort := promMatch[1] + mainPort := mainMatch[1] + + conn, err := net.Dial("tcp", "localhost:"+mainPort) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + conn.Write([]byte("SSH-2.0-test\r\n")) + + time.Sleep(500 * time.Millisecond) + + // Fetch metrics + resp, err := net.Dial("tcp", "localhost:"+promPort) + if err != nil { + t.Fatalf("Failed to connect to metrics endpoint: %v", err) + } + + fmt.Fprintf(resp, "GET /metrics HTTP/1.1\r\nHost: localhost\r\n\r\n") + + buf := make([]byte, 8192) + n, _ := resp.Read(buf) + body := string(buf[:n]) + + if !strings.Contains(body, "endlessh_client_open_count_total") { + t.Errorf("Missing expected metric in output:\n%s", body) + } + + if !strings.Contains(body, "endlessh_sent_bytes_total") { + t.Errorf("Expected bytes metric not found:\n%s", body) + } +} diff --git a/main.go b/main.go index 0436e98..2407950 100644 --- a/main.go +++ b/main.go @@ -156,6 +156,15 @@ func main() { if *connType == "tcp6" && *prometheusHost == "0.0.0.0" { *prometheusHost = "[::]" } + if *prometheusPort == "0" || *prometheusPort == "" { + l, err := net.Listen("tcp", *prometheusHost+":0") + if err != nil { + glog.Fatalf("Failed to pick a free Prometheus port: %v", err) + } + actualPort := l.Addr().(*net.TCPAddr).Port + *prometheusPort = strconv.Itoa(actualPort) + l.Close() + } metrics.InitPrometheus(*prometheusHost, *prometheusPort, *prometheusEntry) }