diff --git a/client/client.go b/client/client.go index ef29cd0..3fca5ff 100644 --- a/client/client.go +++ b/client/client.go @@ -41,16 +41,15 @@ func randStringBytes(n int64) []byte { } type Client struct { - conn net.Conn - next time.Time - start time.Time - last time.Time - interval time.Duration - writeDeadline time.Duration - bytesSent int + conn net.Conn + next time.Time + start time.Time + last time.Time + interval time.Duration + bytesSent int } -func NewClient(conn net.Conn, interval time.Duration, maxClients int64, writeDeadline time.Duration) *Client { +func NewClient(conn net.Conn, interval time.Duration, maxClients int64) *Client { for numCurrentClients >= maxClients { time.Sleep(interval) } @@ -63,13 +62,12 @@ func NewClient(conn net.Conn, interval time.Duration, maxClients int64, writeDea addr := conn.RemoteAddr().(*net.TCPAddr) glog.V(1).Infof("ACCEPT host=%v port=%v n=%v/%v\n", addr.IP, addr.Port, numCurrentClients, maxClients) return &Client{ - conn: conn, - next: time.Now().Add(interval), - start: time.Now(), - last: time.Now(), - interval: interval, - writeDeadline: writeDeadline, - bytesSent: 0, + conn: conn, + next: time.Now().Add(interval), + start: time.Now(), + last: time.Now(), + interval: interval, + bytesSent: 0, } } @@ -90,9 +88,7 @@ func (c *Client) Send(bannerMaxLength int64) (int, error) { // Set a write deadline to detect dead connections where the kernel // buffers data but the remote peer is gone. Without this, Write() // can succeed indefinitely on dead connections, causing goroutine leaks. - if c.writeDeadline > 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeDeadline)) - } + c.conn.SetWriteDeadline(time.Now().Add(c.interval)) bytesSent, err := c.conn.Write(randStringBytes(length)) if err != nil { return 0, err diff --git a/main.go b/main.go index 32a2cd7..6eb2d80 100644 --- a/main.go +++ b/main.go @@ -76,7 +76,7 @@ func startSending(maxClients int64, bannerMaxLength int64, records chan<- metric return clients } -func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, writeDeadline time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) { +func startAccepting(maxClients int64, connType, connHost, connPort string, interval time.Duration, clients chan<- *client.Client, records chan<- metrics.RecordEntry, proxyProtocolEnabled bool, proxyProtocolReadHeaderTimeout int) { go func() { connPortInt, err := strconv.Atoi(connPort) if err != nil { @@ -109,7 +109,7 @@ func startAccepting(maxClients int64, connType, connHost, connPort string, inter glog.Errorf("Error accepting connection from port %v: %v", connPort, err) os.Exit(1) } - c := client.NewClient(conn, interval, maxClients, writeDeadline) + c := client.NewClient(conn, interval, maxClients) remoteIpAddr := c.RemoteIpAddr() records <- metrics.RecordEntry{ RecordType: metrics.RecordEntryTypeStart, @@ -150,7 +150,6 @@ func main() { prometheusCleanUnseenSeconds := flag.Int("prometheus_clean_unseen_seconds", 0, "Remove series if the IP is not seen for the given time. Set to 0 to disable. (default 0)") geoipSupplier := flag.String("geoip_supplier", "off", "Supplier to obtain Geohash of IPs. Possible values are \"off\", \"ip-api\", \"max-mind-db\"") maxMindDbFileName := flag.String("max_mind_db", "", "Path to the MaxMind DB file.") - writeDeadlineMs := flag.Int("write_deadline_ms", 30000, "Write deadline in milliseconds for sending tarpit data. Detects dead connections where the kernel buffers data but the remote peer is gone. Set to 0 to disable. (default 30000)") proxyProtocolEnabled := flag.Bool("proxy_protocol_enabled", false, "Enable PROXY protocol support. This causes the server to expect PROXY protocol headers on incoming connections.") proxyProtocolReadHeaderTimeout := flag.Int("proxy_protocol_read_header_timeout_ms", 200, "Timeout for reading the PROXY protocol header in milliseconds. If the connection does not send a valid PROXY protocol header in this time, the header is ignored.") @@ -175,7 +174,6 @@ func main() { clients := startSending(*maxClients, *bannerMaxLength, records) interval := time.Duration(*intervalMs) * time.Millisecond - writeDeadline := time.Duration(*writeDeadlineMs) * time.Millisecond // Listen for incoming connections. if *connType == "tcp6" && *connHost == "0.0.0.0" { *connHost = "[::]" @@ -184,7 +182,7 @@ func main() { connPorts = append(connPorts, defaultPort) } for _, connPort := range connPorts { - startAccepting(*maxClients, *connType, *connHost, connPort, interval, writeDeadline, clients, records, *proxyProtocolEnabled, *proxyProtocolReadHeaderTimeout) + startAccepting(*maxClients, *connType, *connHost, connPort, interval, clients, records, *proxyProtocolEnabled, *proxyProtocolReadHeaderTimeout) } for { if *prometheusCleanUnseenSeconds <= 0 {