From 2f070772bcdfb6a0946b3b2d4f2bb2c391f2609b Mon Sep 17 00:00:00 2001 From: LBF38 Date: Fri, 27 Mar 2026 16:44:25 +0100 Subject: [PATCH] update tests --- pkg/server/service/transport_test.go | 227 +++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) diff --git a/pkg/server/service/transport_test.go b/pkg/server/service/transport_test.go index 3da10dfd0c..95468282a4 100644 --- a/pkg/server/service/transport_test.go +++ b/pkg/server/service/transport_test.go @@ -1116,3 +1116,230 @@ func TestConnectionTimeouts(t *testing.T) { }) } } + +func TestConnectionTimeoutsAreDefined(t *testing.T) { + testcases := []struct { + desc string + readTimeout ptypes.Duration + writeTimeout ptypes.Duration + expectConnWithTimeout bool + expectedReadTimeout time.Duration + expectedWriteTimeout time.Duration + }{ + { + desc: "read timeout set - should wrap connection with read timeout", + readTimeout: ptypes.Duration(50 * time.Millisecond), + expectConnWithTimeout: true, + expectedReadTimeout: 50 * time.Millisecond, + }, + { + desc: "write timeout set - should wrap connection with write timeout", + writeTimeout: ptypes.Duration(100 * time.Millisecond), + expectConnWithTimeout: true, + expectedWriteTimeout: 100 * time.Millisecond, + }, + { + desc: "both timeouts set - should wrap connection with both timeouts", + readTimeout: ptypes.Duration(30 * time.Millisecond), + writeTimeout: ptypes.Duration(60 * time.Millisecond), + expectConnWithTimeout: true, + expectedReadTimeout: 30 * time.Millisecond, + expectedWriteTimeout: 60 * time.Millisecond, + }, + { + desc: "no timeouts set - should return raw connection", + expectConnWithTimeout: false, + }, + } + + for _, test := range testcases { + t.Run(test.desc, func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + conn.Close() + } + }() + + dialer := &net.Dialer{Timeout: time.Second} + + cfg := &dynamic.ForwardingTimeouts{ + ReadTimeout: test.readTimeout, + WriteTimeout: test.writeTimeout, + } + + dialFn := customDialContext(dialer, cfg) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + conn, err := dialFn(ctx, "tcp", ln.Addr().String()) + require.NoError(t, err, "dial should succeed") + require.NotNil(t, conn) + defer conn.Close() + + if test.expectConnWithTimeout { + wrapped, ok := conn.(*connWithTimeouts) + require.True(t, ok, "expected *connWithTimeouts, got %T", conn) + + if test.expectedReadTimeout > 0 { + require.Equal(t, test.expectedReadTimeout, wrapped.readTimeout, + "read timeout should match configured value") + } + if test.expectedWriteTimeout > 0 { + require.Equal(t, test.expectedWriteTimeout, wrapped.writeTimeout, + "write timeout should match configured value") + } + + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + } else { + _, ok := conn.(*connWithTimeouts) + require.False(t, ok, "should not wrap connection when no timeouts set") + } + }) + } +} + +func TestConnectionTimeoutsAlternative(t *testing.T) { + testcases := []struct { + desc string + readTimeout ptypes.Duration + writeTimeout ptypes.Duration + serverWriteDelay time.Duration + serverReadDelay time.Duration + expectedReadTimeoutErr bool + expectedWriteTimeoutErr bool + }{ + { + desc: "read timeout - server takes longer than timeout", + readTimeout: ptypes.Duration(50 * time.Millisecond), + serverWriteDelay: 200 * time.Millisecond, + expectedReadTimeoutErr: true, + }, + { + desc: "read timeout - server responds within timeout", + readTimeout: ptypes.Duration(200 * time.Millisecond), + serverWriteDelay: 50 * time.Millisecond, + }, + { + desc: "no read timeout - should succeed regardless of delay", + serverWriteDelay: 50 * time.Millisecond, + }, + { + desc: "write timeout - server takes longer than timeout", + writeTimeout: ptypes.Duration(50 * time.Millisecond), + serverReadDelay: 300 * time.Millisecond, + }, + { + desc: "write timeout - server responds within timeout", + writeTimeout: ptypes.Duration(200 * time.Millisecond), + serverReadDelay: 50 * time.Millisecond, + }, + { + desc: "no write timeout - should succeed regardless of delay", + serverReadDelay: 50 * time.Millisecond, + }, + } + + for _, test := range testcases { + t.Run(test.desc, func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "failed to create listener") + defer ln.Close() + + serverErrCh := make(chan error, 1) + go func() { + srvConn, err := ln.Accept() + if err != nil { + serverErrCh <- err + return + } + + if test.serverWriteDelay > 0 { + time.Sleep(test.serverWriteDelay) + } + _, err = srvConn.Write([]byte("HELLO")) + if err != nil { + srvConn.Close() + serverErrCh <- err + return + } + + if test.serverReadDelay > 0 { + time.Sleep(test.serverReadDelay) + } + buf := make([]byte, 1024) + _, err = srvConn.Read(buf) + srvConn.Close() + serverErrCh <- err + }() + + cfg := &dynamic.ForwardingTimeouts{ + ReadTimeout: test.readTimeout, + WriteTimeout: test.writeTimeout, + } + + dialer := &net.Dialer{Timeout: 5 * time.Second} + dialFn := customDialContext(dialer, cfg) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := dialFn(ctx, "tcp", ln.Addr().String()) + require.NoError(t, err, "failed to dial") + require.NotNil(t, conn) + defer conn.Close() + + select { + case err := <-serverErrCh: + require.NoError(t, err, "server error") + default: + } + + if test.readTimeout > 0 { + buf := make([]byte, 5) + readStart := time.Now() + _, err = conn.Read(buf) + elapsed := time.Since(readStart) + + if test.expectedReadTimeoutErr { + require.Error(t, err, "expected read timeout error") + var netErr net.Error + ok := errors.As(err, &netErr) + require.True(t, ok, "expected net.Error, got %T", err) + require.True(t, netErr.Timeout(), "expected timeout error") + require.GreaterOrEqual(t, elapsed, test.readTimeout, + "timeout should trigger after configured duration") + } else if err != nil && !errors.Is(err, io.EOF) { + require.Fail(t, "unexpected error", err) + } + } + + if test.writeTimeout > 0 { + writeStart := time.Now() + _, err = conn.Write([]byte("WORLD")) + elapsed := time.Since(writeStart) + + if test.expectedWriteTimeoutErr { + require.Error(t, err, "expected write timeout error") + var netErr net.Error + ok := errors.As(err, &netErr) + require.True(t, ok, "expected net.Error, got %T", err) + require.True(t, netErr.Timeout(), "expected timeout error") + require.GreaterOrEqual(t, elapsed, test.writeTimeout, + "timeout should trigger after configured duration") + } else if err != nil && !errors.Is(err, io.EOF) { + require.Fail(t, "unexpected error", err) + } + } + }) + } +}