diff --git a/component-tests/cli.py b/component-tests/cli.py index 619ae65d..ac22c54e 100644 --- a/component-tests/cli.py +++ b/component-tests/cli.py @@ -1,10 +1,9 @@ import json import subprocess -from time import sleep from constants import MANAGEMENT_HOST_NAME from setup import get_config_from_file -from util import get_tunnel_connector_id +from util import get_tunnel_connector_id, CloudflaredProcess SINGLE_CASE_TIMEOUT = 600 @@ -83,38 +82,12 @@ class CloudflaredCli: def __enter__(self): self.basecmd += ["run"] - self.process = subprocess.Popen(self.basecmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) self.logger.info(f"Run cmd {self.basecmd}") - return self.process + self.cfd = CloudflaredProcess(self.basecmd, allow_input=False, capture_output=True) + return self.cfd def __exit__(self, exc_type, exc_value, exc_traceback): - terminate_gracefully(self.process, self.logger, self.basecmd) - self.logger.debug(f"{self.basecmd} logs: {self.process.stderr.read()}") - - -def terminate_gracefully(process, logger, cmd): - process.terminate() - process_terminated = wait_for_terminate(process) - if not process_terminated: - process.kill() - logger.warning(f"{cmd}: cloudflared did not terminate within wait period. Killing process. logs: \ - stdout: {process.stdout.read()}, stderr: {process.stderr.read()}") - - -def wait_for_terminate(opened_subprocess, attempts=10, poll_interval=1): - """ - wait_for_terminate polls the opened_subprocess every x seconds for a given number of attempts. - It returns true if the subprocess was terminated and false if it didn't. - """ - for _ in range(attempts): - if _is_process_stopped(opened_subprocess): - return True - sleep(poll_interval) - return False - - -def _is_process_stopped(process): - return process.poll() is not None + self.cfd.cleanup() def cert_path(): diff --git a/component-tests/constants.py b/component-tests/constants.py index 52d96952..1b90ab41 100644 --- a/component-tests/constants.py +++ b/component-tests/constants.py @@ -5,6 +5,17 @@ MAX_LOG_LINES = 50 MANAGEMENT_HOST_NAME = "management.argotunnel.com" +# How long to wait for the cloudflared process to exit after SIGTERM before +# sending SIGKILL. +GRACEFUL_SHUTDOWN_TIMEOUT = 10 +# How long to wait for each pipe reader thread to finish after the process +# exits. +READER_THREAD_JOIN_TIMEOUT = 5 +# How long to wait for an expected log message to appear before giving up. +LOG_POLL_TIMEOUT = 30 +# How often to re-check the accumulated log lines while polling. +LOG_POLL_INTERVAL = 0.5 + def protocols(): return ["http2", "quic"] diff --git a/component-tests/test_edge_discovery.py b/component-tests/test_edge_discovery.py index 8fd3b036..b0fcd6d8 100644 --- a/component-tests/test_edge_discovery.py +++ b/component-tests/test_edge_discovery.py @@ -17,25 +17,6 @@ class TestEdgeDiscovery: config["edge-ip-version"] = edge_ip_version return config - @pytest.mark.parametrize("protocol", protocols()) - def test_default_only(self, tmp_path, component_tests_config, protocol): - """ - This test runs a tunnel with the default edge-ip-version (auto), which will use - whichever address family the system resolver returns first. - """ - if self.has_ipv6_only(): - self.expect_address_connections( - tmp_path, component_tests_config, protocol, None, self.expect_ipv6_address) - elif self.has_ipv4_only(): - self.expect_address_connections( - tmp_path, component_tests_config, protocol, None, self.expect_ipv4_address) - elif self.has_dual_stack(address_family_preference=socket.AddressFamily.AF_INET6): - self.expect_address_connections( - tmp_path, component_tests_config, protocol, None, self.expect_ipv6_address) - else: - self.expect_address_connections( - tmp_path, component_tests_config, protocol, None, self.expect_ipv4_address) - @pytest.mark.parametrize("protocol", protocols()) def test_ipv4_only(self, tmp_path, component_tests_config, protocol): """ diff --git a/component-tests/test_logging.py b/component-tests/test_logging.py index 91af2e58..1607dbcd 100644 --- a/component-tests/test_logging.py +++ b/component-tests/test_logging.py @@ -1,8 +1,9 @@ #!/usr/bin/env python import json import os +import time -from constants import MAX_LOG_LINES +from constants import MAX_LOG_LINES, LOG_POLL_INTERVAL, LOG_POLL_TIMEOUT from util import start_cloudflared, wait_tunnel_ready, send_requests # Rolling logger rotate log files after 1 MB @@ -12,12 +13,14 @@ expect_message = "Starting Hello" def assert_log_to_terminal(cloudflared): - for _ in range(0, MAX_LOG_LINES): - line = cloudflared.stderr.readline() - if not line: - break - if expect_message.encode() in line: - return + # All logs are drained by a background thread into cloudflared.stdout_lines. + # Poll the accumulated lines until the expected message appears. + deadline = time.monotonic() + LOG_POLL_TIMEOUT + while time.monotonic() < deadline: + for line in list(cloudflared.stdout_lines): + if expect_message.encode() in line: + return + time.sleep(LOG_POLL_INTERVAL) raise Exception(f"terminal log doesn't contain {expect_message}") diff --git a/component-tests/util.py b/component-tests/util.py index f45a17fd..21ec8738 100644 --- a/component-tests/util.py +++ b/component-tests/util.py @@ -2,6 +2,7 @@ import logging import os import platform import subprocess +import threading from contextlib import contextmanager from time import sleep import sys @@ -12,7 +13,65 @@ import requests import yaml from retrying import retry -from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS +from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS, GRACEFUL_SHUTDOWN_TIMEOUT, READER_THREAD_JOIN_TIMEOUT + +class CloudflaredProcess: + """ + Wrapper around a Popen process that continuously drains stdout and stderr + in background threads to prevent OS pipe buffers from filling up and + blocking the child process. Captured output is logged when the process + is cleaned up. + """ + + def __init__(self, cmd, allow_input, capture_output): + output = subprocess.PIPE if capture_output else subprocess.DEVNULL + stdin = subprocess.PIPE if allow_input else None + self.process = subprocess.Popen(cmd, stdin=stdin, stdout=output, stderr=subprocess.STDOUT) + + self._capture_output = capture_output + self._stdout_lines = [] + self._threads = [] + if capture_output: + self._threads.append(self._start_reader(self.process.stdout, self._stdout_lines)) + + @staticmethod + def _start_reader(pipe, sink): + def _drain(): + for line in pipe: + sink.append(line) + pipe.close() + t = threading.Thread(target=_drain, daemon=True) + t.start() + return t + + def terminate(self): + """Terminate the process if it is still running.""" + if self.process.poll() is None: + self.process.terminate() + + def cleanup(self): + """Terminate, wait for exit, join reader threads, and log output.""" + self.terminate() + try: + self.process.wait(timeout=GRACEFUL_SHUTDOWN_TIMEOUT) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + for t in self._threads: + t.join(timeout=READER_THREAD_JOIN_TIMEOUT) + if self._capture_output: + stdout = b"".join(self._stdout_lines).decode("utf-8", errors="replace") + if stdout: + LOGGER.info(f"cloudflared stdout:\n{stdout}") + + @property + def stdout_lines(self): + return self._stdout_lines + + # Proxy common Popen attributes so callers can still use the wrapper + # as if it were a Popen (e.g. send_signal, stdin, pid, returncode). + def __getattr__(self, name): + return getattr(self.process, name) def configure_logger(): logger = logging.getLogger(__name__) @@ -75,20 +134,15 @@ def cloudflared_cmd(config, config_path, cfd_args, cfd_pre_args, root): LOGGER.info(f"Run cmd {cmd} with config {config}") return cmd - @contextmanager def run_cloudflared_background(cmd, allow_input, capture_output): - output = subprocess.PIPE if capture_output else subprocess.DEVNULL - stdin = subprocess.PIPE if allow_input else None cfd = None try: - cfd = subprocess.Popen(cmd, stdin=stdin, stdout=output, stderr=output) + cfd = CloudflaredProcess(cmd, allow_input, capture_output) yield cfd finally: if cfd: - cfd.terminate() - if capture_output: - LOGGER.info(f"cloudflared log: {cfd.stderr.read()}") + cfd.cleanup() def get_quicktunnel_url():