From a32bef8de598ed0ded4b81961e6b6198909aa557 Mon Sep 17 00:00:00 2001 From: cyphercodes Date: Tue, 28 Apr 2026 05:45:20 +0300 Subject: [PATCH] fix: avoid UDP timeout before first activity --- pkg/udp/conn.go | 22 +++++++----- pkg/udp/conn_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/pkg/udp/conn.go b/pkg/udp/conn.go index 5e50af39c5..90bec10ed7 100644 --- a/pkg/udp/conn.go +++ b/pkg/udp/conn.go @@ -273,10 +273,7 @@ func (c *Conn) readLoop() { case msg := <-c.receiveCh: c.msgs = append(c.msgs, msg) case <-ticker.C: - c.muActivity.RLock() - deadline := c.lastActivity.Add(c.timeout) - c.muActivity.RUnlock() - if time.Now().After(deadline) { + if c.hasTimedOut() { c.Close() return } @@ -293,10 +290,7 @@ func (c *Conn) readLoop() { case msg := <-c.receiveCh: c.msgs = append(c.msgs, msg) case <-ticker.C: - c.muActivity.RLock() - deadline := c.lastActivity.Add(c.timeout) - c.muActivity.RUnlock() - if time.Now().After(deadline) { + if c.hasTimedOut() { c.Close() return } @@ -304,6 +298,18 @@ func (c *Conn) readLoop() { } } +func (c *Conn) hasTimedOut() bool { + c.muActivity.RLock() + lastActivity := c.lastActivity + c.muActivity.RUnlock() + + if lastActivity.IsZero() { + return false + } + + return time.Now().After(lastActivity.Add(c.timeout)) +} + func (c *Conn) close() { c.doneOnce.Do(func() { close(c.doneCh) diff --git a/pkg/udp/conn_test.go b/pkg/udp/conn_test.go index 3e9459102e..2937bf9650 100644 --- a/pkg/udp/conn_test.go +++ b/pkg/udp/conn_test.go @@ -160,12 +160,89 @@ func TestListenWithZeroTimeout(t *testing.T) { assert.Error(t, err) } +func TestTimeoutDoesNotCloseBeforeFirstRead(t *testing.T) { + ln, err := Listen(net.ListenConfig{}, "udp", ":0", time.Millisecond) + require.NoError(t, err) + defer func() { + err := ln.Close() + require.NoError(t, err) + }() + + accepted := make(chan *Conn) + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + accepted <- conn + }() + + udpConn, err := net.Dial("udp", ln.Addr().String()) + require.NoError(t, err) + + _, err = udpConn.Write([]byte("TEST")) + require.NoError(t, err) + + conn := <-accepted + time.Sleep(20 * time.Millisecond) + + type readResult struct { + n int + err error + } + resultCh := make(chan readResult) + go func() { + buf := make([]byte, 2048) + n, err := conn.Read(buf) + if err == nil { + assert.Equal(t, "TEST", string(buf[:n])) + } + resultCh <- readResult{n: n, err: err} + }() + + select { + case result := <-resultCh: + require.NoError(t, result.err) + require.Equal(t, 4, result.n) + case <-time.Tick(time.Second): + t.Fatal("Timeout during first read") + } +} + func TestTimeoutWithRead(t *testing.T) { testTimeout(t, true) } -func TestTimeoutWithoutRead(t *testing.T) { - testTimeout(t, false) +func TestTimeoutAfterFirstRead(t *testing.T) { + ln, err := Listen(net.ListenConfig{}, "udp", ":0", 50*time.Millisecond) + require.NoError(t, err) + defer func() { + err := ln.Close() + require.NoError(t, err) + }() + + accepted := make(chan *Conn) + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + accepted <- conn + }() + + udpConn, err := net.Dial("udp", ln.Addr().String()) + require.NoError(t, err) + + _, err = udpConn.Write([]byte("TEST")) + require.NoError(t, err) + + conn := <-accepted + time.Sleep(2 * ln.timeout) + assert.Len(t, ln.conns, 1) + + buf := make([]byte, 2048) + n, err := conn.Read(buf) + require.NoError(t, err) + require.Equal(t, "TEST", string(buf[:n])) + + time.Sleep(2 * ln.timeout) + assert.Empty(t, ln.conns) } func testTimeout(t *testing.T, withRead bool) {