refactor: use interval as write deadline instead of separate parameter

Simplify the ghost connection fix by reusing the existing interval
duration as the write deadline. This removes the need for a separate
-write_deadline_ms flag while maintaining the same protection against
goroutine leaks from dead connections.
This commit is contained in:
darkwolf
2026-03-12 09:26:38 +01:00
parent ecdfc514d0
commit b6b3fe2678
2 changed files with 17 additions and 23 deletions
+14 -18
View File
@@ -41,16 +41,15 @@ func randStringBytes(n int64) []byte {
}
type Client struct {
conn net.Conn
next time.Time
start time.Time
last time.Time
interval time.Duration
writeDeadline time.Duration
bytesSent int
conn net.Conn
next time.Time
start time.Time
last time.Time
interval time.Duration
bytesSent int
}
func NewClient(conn net.Conn, interval time.Duration, maxClients int64, writeDeadline time.Duration) *Client {
func NewClient(conn net.Conn, interval time.Duration, maxClients int64) *Client {
for numCurrentClients >= maxClients {
time.Sleep(interval)
}
@@ -63,13 +62,12 @@ func NewClient(conn net.Conn, interval time.Duration, maxClients int64, writeDea
addr := conn.RemoteAddr().(*net.TCPAddr)
glog.V(1).Infof("ACCEPT host=%v port=%v n=%v/%v\n", addr.IP, addr.Port, numCurrentClients, maxClients)
return &Client{
conn: conn,
next: time.Now().Add(interval),
start: time.Now(),
last: time.Now(),
interval: interval,
writeDeadline: writeDeadline,
bytesSent: 0,
conn: conn,
next: time.Now().Add(interval),
start: time.Now(),
last: time.Now(),
interval: interval,
bytesSent: 0,
}
}
@@ -90,9 +88,7 @@ func (c *Client) Send(bannerMaxLength int64) (int, error) {
// Set a write deadline to detect dead connections where the kernel
// buffers data but the remote peer is gone. Without this, Write()
// can succeed indefinitely on dead connections, causing goroutine leaks.
if c.writeDeadline > 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeDeadline))
}
c.conn.SetWriteDeadline(time.Now().Add(c.interval))
bytesSent, err := c.conn.Write(randStringBytes(length))
if err != nil {
return 0, err
+3 -5
View File
@@ -76,7 +76,7 @@ func startSending(maxClients int64, bannerMaxLength int64, records chan<- metric
return clients
}
func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, writeDeadline time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) {
func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) {
go func() {
connPortInt, err := strconv.Atoi(connPort)
if err != nil {
@@ -109,7 +109,7 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter
glog.Errorf("Error accepting connection from port %v: %v", connPort, err)
os.Exit(1)
}
c := client.NewClient(conn, interval, maxClients, writeDeadline)
c := client.NewClient(conn, interval, maxClients)
remoteIpAddr := c.RemoteIpAddr()
records <- metrics.RecordEntry{
RecordType: metrics.RecordEntryTypeStart,
@@ -150,7 +150,6 @@ func main() {
prometheusCleanUnseenSeconds := flag.Int("prometheus_clean_unseen_seconds", 0, "Remove series if the IP is not seen for the given time. Set to 0 to disable. (default 0)")
geoipSupplier := flag.String("geoip_supplier", "off", "Supplier to obtain Geohash of IPs. Possible values are \"off\", \"ip-api\", \"max-mind-db\"")
maxMindDbFileName := flag.String("max_mind_db", "", "Path to the MaxMind DB file.")
writeDeadlineMs := flag.Int("write_deadline_ms", 30000, "Write deadline in milliseconds for sending tarpit data. Detects dead connections where the kernel buffers data but the remote peer is gone. Set to 0 to disable. (default 30000)")
proxyProtocolEnabled := flag.Bool("proxy_protocol_enabled", false, "Enable PROXY protocol support. This causes the server to expect PROXY protocol headers on incoming connections.")
proxyProtocolReadHeaderTimeout := flag.Int("proxy_protocol_read_header_timeout_ms", 200, "Timeout for reading the PROXY protocol header in milliseconds. If the connection does not send a valid PROXY protocol header in this time, the header is ignored.")
@@ -175,7 +174,6 @@ func main() {
clients := startSending(*maxClients, *bannerMaxLength, records)
interval := time.Duration(*intervalMs) * time.Millisecond
writeDeadline := time.Duration(*writeDeadlineMs) * time.Millisecond
// Listen for incoming connections.
if *connType == "tcp6" && *connHost == "0.0.0.0" {
*connHost = "[::]"
@@ -184,7 +182,7 @@ func main() {
connPorts = append(connPorts, defaultPort)
}
for _, connPort := range connPorts {
startAccepting(*maxClients, *connType, *connHost, connPort, interval, writeDeadline, clients, records, *proxyProtocolEnabled, *proxyProtocolReadHeaderTimeout)
startAccepting(*maxClients, *connType, *connHost, connPort, interval, clients, records, *proxyProtocolEnabled, *proxyProtocolReadHeaderTimeout)
}
for {
if *prometheusCleanUnseenSeconds <= 0 {