diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c59b6a802..e998610201 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +## [v2.11.41](https://github.com/traefik/traefik/tree/v2.11.41) (2026-03-18) +[All Commits](https://github.com/traefik/traefik/compare/v2.11.40...v2.11.41) + +**Bug fixes:** +- **[http]** Add maxResponseBodySize configuration on HTTP provider ([#12788](https://github.com/traefik/traefik/pull/12788) @gndz07) +- **[tls]** Support fragmented TLS client hello ([#12787](https://github.com/traefik/traefik/pull/12787) @rtribotte) +- **[middleware, authentication]** Make basic auth check timing constant ([#12803](https://github.com/traefik/traefik/pull/12803) @rtribotte) + +**Documentation:** +- Bump mkdocs-traefiklabs to use consent mode ([#12804](https://github.com/traefik/traefik/pull/12804) @darkweaver87) + ## [v3.6.10](https://github.com/traefik/traefik/tree/v3.6.10) (2026-03-06) [All Commits](https://github.com/traefik/traefik/compare/v3.6.9...v3.6.10) diff --git a/docs/content/middlewares/http/basicauth.md b/docs/content/middlewares/http/basicauth.md index f7017842ec..072183914f 100644 --- a/docs/content/middlewares/http/basicauth.md +++ b/docs/content/middlewares/http/basicauth.md @@ -10,6 +10,12 @@ Adding Basic Authentication The BasicAuth middleware grants access to services to authorized users only. +!!! warning "Timing attacks" + + The BasicAuth middleware is vulnerable to timing attacks when the configured users' password hashes do not use the same algorithm and cost. + However, when the configured user's password hashes are of the same algorithm and cost, the middleware guarantees the same comparison time between existing and non-existing users. + This prevents an attacker from leveraging the time difference to determine whether a user exists. + ## Configuration Examples ```yaml tab="Docker & Swarm" diff --git a/docs/content/providers/http.md b/docs/content/providers/http.md index 775ec88181..8f5c5f4d02 100644 --- a/docs/content/providers/http.md +++ b/docs/content/providers/http.md @@ -200,3 +200,25 @@ providers: ```bash tab="CLI" --providers.http.tls.insecureSkipVerify=true ``` + +### `maxResponseBodySize` + +_Optional, Default=-1_ + +Defines the maximum size of the response body in bytes. +If left unset (or set to -1), the response body size is unrestricted which can have performance implications. + +```yaml tab="File (YAML)" +providers: + http: + maxResponseBodySize: -1 +``` + +```toml tab="File (TOML)" +[providers.http] + maxResponseBodySize = -1 +``` + +```bash tab="CLI" +--providers.http.maxResponseBodySize=-1 +``` diff --git a/docs/content/reference/install-configuration/configuration-options.md b/docs/content/reference/install-configuration/configuration-options.md index 23de297485..51077dfeaf 100644 --- a/docs/content/reference/install-configuration/configuration-options.md +++ b/docs/content/reference/install-configuration/configuration-options.md @@ -321,6 +321,7 @@ THIS FILE MUST NOT BE EDITED BY HAND | providers.http | Enables HTTP provider. | false | | providers.http.endpoint | Load configuration from this endpoint. | | | providers.http.headers._name_ | Define custom headers to be sent to the endpoint. | | +| providers.http.maxresponsebodysize | Defines the maximum size of the response body in bytes. | -1 | | providers.http.pollinterval | Polling interval for endpoint. | 5 | | providers.http.polltimeout | Polling timeout for endpoint. | 5 | | providers.http.tls.ca | TLS CA | | diff --git a/docs/content/reference/static-configuration/cli-ref.md b/docs/content/reference/static-configuration/cli-ref.md index 8dc085557b..43c2dd5376 100644 --- a/docs/content/reference/static-configuration/cli-ref.md +++ b/docs/content/reference/static-configuration/cli-ref.md @@ -933,6 +933,9 @@ Load configuration from this endpoint. `--providers.http.headers.`: Define custom headers to be sent to the endpoint. +`--providers.http.maxresponsebodysize`: +Defines the maximum size of the response body in bytes. (Default: ```-1```) + `--providers.http.pollinterval`: Polling interval for endpoint. (Default: ```5```) diff --git a/docs/content/reference/static-configuration/env-ref.md b/docs/content/reference/static-configuration/env-ref.md index efc26bd0de..f3ad394e2e 100644 --- a/docs/content/reference/static-configuration/env-ref.md +++ b/docs/content/reference/static-configuration/env-ref.md @@ -933,6 +933,9 @@ Load configuration from this endpoint. `TRAEFIK_PROVIDERS_HTTP_HEADERS_`: Define custom headers to be sent to the endpoint. +`TRAEFIK_PROVIDERS_HTTP_MAXRESPONSEBODYSIZE`: +Defines the maximum size of the response body in bytes. (Default: ```-1```) + `TRAEFIK_PROVIDERS_HTTP_POLLINTERVAL`: Polling interval for endpoint. (Default: ```5```) diff --git a/docs/content/reference/static-configuration/file.toml b/docs/content/reference/static-configuration/file.toml index ed9cb0dad3..33dacd1140 100644 --- a/docs/content/reference/static-configuration/file.toml +++ b/docs/content/reference/static-configuration/file.toml @@ -310,6 +310,7 @@ [providers.http.headers] name0 = "foobar" name1 = "foobar" + maxResponseBodySize = 42 [providers.http.tls] ca = "foobar" cert = "foobar" diff --git a/docs/content/reference/static-configuration/file.yaml b/docs/content/reference/static-configuration/file.yaml index 89071c3d19..f2b1d45d8e 100644 --- a/docs/content/reference/static-configuration/file.yaml +++ b/docs/content/reference/static-configuration/file.yaml @@ -354,6 +354,7 @@ providers: cert: foobar key: foobar insecureSkipVerify: true + maxResponseBodySize: 42 plugin: PluginConf0: name0: foobar diff --git a/docs/requirements.txt b/docs/requirements.txt index 7801cf7377..a46901d8e2 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ mkdocs==1.4.3 mkdocs-include-markdown-plugin==7.2.0 mkdocs-exclude==1.0.2 -mkdocs-traefiklabs>=100.0.7 +mkdocs-traefiklabs>=100.1.0 mkdocs-redirects==1.2.2 click==8.1.7 diff --git a/pkg/middlewares/auth/basic_auth.go b/pkg/middlewares/auth/basic_auth.go index 863c968f31..18d7995e0e 100644 --- a/pkg/middlewares/auth/basic_auth.go +++ b/pkg/middlewares/auth/basic_auth.go @@ -3,8 +3,10 @@ package auth import ( "context" "fmt" + "maps" "net/http" "net/url" + "slices" "strings" goauth "github.com/abbot/go-http-auth" @@ -27,6 +29,7 @@ type basicAuth struct { removeHeader bool name string + notFoundSecret string checkSecret func(password, secret string) bool singleflightGroup *singleflight.Group } @@ -40,12 +43,18 @@ func NewBasic(ctx context.Context, next http.Handler, authConfig dynamic.BasicAu return nil, err } + // To prevent timing attacks, we need to compute a hash even if the user is not found. + // We assume it to be safe only when the users hashes are all from the same algorithm, + // so we can pick the first one as a random hash to compute. + notFoundSecret := users[slices.Collect(maps.Values(users))[0]] + ba := &basicAuth{ next: next, users: users, headerField: authConfig.HeaderField, removeHeader: authConfig.RemoveHeader, name: name, + notFoundSecret: notFoundSecret, checkSecret: goauth.CheckSecret, singleflightGroup: new(singleflight.Group), } @@ -68,8 +77,9 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger := middlewares.GetLogger(req.Context(), b.name, typeNameBasic) user, password, ok := req.BasicAuth() + var authenticated bool if ok { - ok = b.checkPassword(user, password) + authenticated = b.checkPassword(user, password) } logData := accesslog.GetLogData(req) @@ -77,7 +87,7 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logData.Core[accesslog.ClientUsername] = user } - if !ok { + if !authenticated { logger.Debug().Msg("Authentication failed") observability.SetStatusErrorf(req.Context(), "Authentication failed") @@ -101,19 +111,21 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (b *basicAuth) checkPassword(user, password string) bool { secret := b.auth.Secrets(user, b.auth.Realm) - if secret == "" { - return false - } key := password + secret match, _, _ := b.singleflightGroup.Do(key, func() (any, error) { + if secret == "" { + _ = b.checkSecret(password, b.notFoundSecret) + return false, nil + } + return b.checkSecret(password, secret), nil }) return match.(bool) } -func (b *basicAuth) secretBasic(user, realm string) string { +func (b *basicAuth) secretBasic(user, _ string) string { if secret, ok := b.users[user]; ok { return secret } diff --git a/pkg/provider/http/http.go b/pkg/provider/http/http.go index 4712e1b0a9..757bd3a3d9 100644 --- a/pkg/provider/http/http.go +++ b/pkg/provider/http/http.go @@ -25,6 +25,8 @@ import ( var _ provider.Provider = (*Provider)(nil) +const defaultMaxResponseBodySize = -1 + // Provider is a provider.Provider implementation that queries an HTTP(s) endpoint for a configuration. type Provider struct { Endpoint string `description:"Load configuration from this endpoint." json:"endpoint" toml:"endpoint" yaml:"endpoint"` @@ -35,12 +37,14 @@ type Provider struct { httpClient *http.Client lastConfigurationHash uint64 + MaxResponseBodySize int64 `description:"Defines the maximum size of the response body in bytes." json:"maxResponseBodySize,omitempty" toml:"maxResponseBodySize,omitempty" yaml:"maxResponseBodySize,omitempty" export:"true"` } // SetDefaults sets the default values. func (p *Provider) SetDefaults() { p.PollInterval = ptypes.Duration(5 * time.Second) p.PollTimeout = ptypes.Duration(5 * time.Second) + p.MaxResponseBodySize = defaultMaxResponseBodySize } // Init the provider. @@ -168,7 +172,19 @@ func (p *Provider) fetchConfigurationData() ([]byte, error) { return nil, fmt.Errorf("received non-ok response code: %d", res.StatusCode) } - return io.ReadAll(res.Body) + if p.MaxResponseBodySize < 0 { + return io.ReadAll(res.Body) + } + + data, err := io.ReadAll(io.LimitReader(res.Body, p.MaxResponseBodySize+1)) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + if int64(len(data)) > p.MaxResponseBodySize { + return nil, errors.New("response body too large") + } + + return data, nil } // decodeConfiguration decodes and returns the dynamic configuration from the given data. diff --git a/pkg/provider/http/http_test.go b/pkg/provider/http/http_test.go index eae98d9798..3165ebfb11 100644 --- a/pkg/provider/http/http_test.go +++ b/pkg/provider/http/http_test.go @@ -14,6 +14,7 @@ import ( "github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/safe" "github.com/traefik/traefik/v3/pkg/tls" + "k8s.io/utils/ptr" ) func TestProvider_Init(t *testing.T) { @@ -65,15 +66,17 @@ func TestProvider_SetDefaults(t *testing.T) { assert.Equal(t, provider.PollInterval, ptypes.Duration(5*time.Second)) assert.Equal(t, provider.PollTimeout, ptypes.Duration(5*time.Second)) + assert.Equal(t, int64(-1), provider.MaxResponseBodySize) } func TestProvider_fetchConfigurationData(t *testing.T) { tests := []struct { - desc string - statusCode int - headers map[string]string - expData []byte - expErr require.ErrorAssertionFunc + desc string + statusCode int + headers map[string]string + expData []byte + expErr require.ErrorAssertionFunc + maxResponseBodySize *int64 }{ { desc: "should return the fetched configuration data", @@ -97,6 +100,25 @@ func TestProvider_fetchConfigurationData(t *testing.T) { statusCode: http.StatusInternalServerError, expErr: require.Error, }, + { + desc: "should return an error response body is too long when maxResponseBodySize is 0", + statusCode: http.StatusOK, + maxResponseBodySize: ptr.To(int64(0)), + expErr: require.Error, + }, + { + desc: "should return an error response body is too long when response is longer than maxResponseBodySize", + statusCode: http.StatusOK, + maxResponseBodySize: ptr.To(int64(1)), + expErr: require.Error, + }, + { + desc: "should return the fetched configuration data when response is the same length with maxResponseBodySize", + statusCode: http.StatusOK, + maxResponseBodySize: ptr.To(int64(2)), + expData: []byte("{}"), + expErr: require.NoError, + }, } for _, test := range tests { @@ -118,11 +140,15 @@ func TestProvider_fetchConfigurationData(t *testing.T) { })) defer srv.Close() - provider := Provider{ - Endpoint: srv.URL, - Headers: test.headers, - PollInterval: ptypes.Duration(1 * time.Second), - PollTimeout: ptypes.Duration(1 * time.Second), + var provider Provider + provider.SetDefaults() + + provider.Headers = test.headers + provider.Endpoint = srv.URL + provider.PollTimeout = ptypes.Duration(1 * time.Second) + provider.PollInterval = ptypes.Duration(1 * time.Second) + if test.maxResponseBodySize != nil { + provider.MaxResponseBodySize = *test.maxResponseBodySize } err := provider.Init() @@ -201,11 +227,12 @@ func TestProvider_Provide(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - provider := Provider{ - Endpoint: server.URL, - PollTimeout: ptypes.Duration(1 * time.Second), - PollInterval: ptypes.Duration(100 * time.Millisecond), - } + var provider Provider + provider.SetDefaults() + + provider.Endpoint = server.URL + provider.PollTimeout = ptypes.Duration(1 * time.Second) + provider.PollInterval = ptypes.Duration(100 * time.Millisecond) err := provider.Init() require.NoError(t, err) @@ -257,11 +284,12 @@ func TestProvider_ProvideConfigurationOnlyOnceIfUnchanged(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(handler)) defer server.Close() - provider := Provider{ - Endpoint: server.URL + "/endpoint", - PollTimeout: ptypes.Duration(1 * time.Second), - PollInterval: ptypes.Duration(100 * time.Millisecond), - } + var provider Provider + provider.SetDefaults() + + provider.Endpoint = server.URL + "/endpoint" + provider.PollTimeout = ptypes.Duration(1 * time.Second) + provider.PollInterval = ptypes.Duration(100 * time.Millisecond) err := provider.Init() require.NoError(t, err) diff --git a/pkg/server/router/tcp/postgres.go b/pkg/server/router/tcp/postgres.go index dedb7d6f4e..d71cb7f465 100644 --- a/pkg/server/router/tcp/postgres.go +++ b/pkg/server/router/tcp/postgres.go @@ -51,7 +51,8 @@ func (r *Router) servePostgres(conn tcp.WriteCloser) { return } - br := bufio.NewReader(conn) + var peeked bytes.Buffer + br := bufio.NewReader(io.TeeReader(conn, &peeked)) b := make([]byte, len(PostgresStartTLSMsg)) _, err = br.Read(b) @@ -93,7 +94,7 @@ func (r *Router) servePostgres(conn tcp.WriteCloser) { } // We are in TLS mode and if the handler is not TLSHandler, we are in passthrough. - proxiedConn := r.GetConn(conn, hello.peeked) + proxiedConn := r.GetConn(conn, peeked.String()) if _, ok := handlerTCPTLS.(*tcp.TLSHandler); !ok { proxiedConn = &postgresConn{WriteCloser: proxiedConn} } diff --git a/pkg/server/router/tcp/router.go b/pkg/server/router/tcp/router.go index 81400e3f1e..b59c172dbc 100644 --- a/pkg/server/router/tcp/router.go +++ b/pkg/server/router/tcp/router.go @@ -19,15 +19,8 @@ import ( "github.com/traefik/traefik/v3/pkg/tcp" ) -const ( - defaultBufSize = 4096 - // Per RFC 8446 Section 5.1, the maximum TLS record payload length is 2^14 (16384) bytes. - // A ClientHello is always a plaintext record, so any value exceeding this limit is invalid - // and likely indicates an attack attempting to force oversized per-connection buffer allocations. - // However, in practice the go server handshake can read up to 16384 + 2048 bytes, - // so we need to allow for some extra bytes to avoid rejecting valid handshakes. - maxTLSRecordLen = 16384 + 2048 -) +// errClientHelloRead is used as a sentinel error to break the TLS handshake once we have read the ClientHello. +var errClientHelloRead = errors.New("client hello successfully read") // Router is a TCP router. type Router struct { @@ -127,7 +120,9 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { } // TODO -- Check if ProxyProtocol changes the first bytes of the request - br := bufio.NewReader(conn) + var peeked bytes.Buffer + br := bufio.NewReader(io.TeeReader(conn, &peeked)) + postgres, err := isPostgres(br) if err != nil { conn.Close() @@ -135,7 +130,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { } if postgres { - r.servePostgres(r.GetConn(conn, getPeeked(br))) + r.servePostgres(r.GetConn(conn, peeked.String())) return } @@ -168,9 +163,9 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { handler, _ := r.muxerTCP.Match(connData) switch { case handler != nil: - handler.ServeTCP(r.GetConn(conn, hello.peeked)) + handler.ServeTCP(r.GetConn(conn, peeked.String())) case r.httpForwarder != nil: - r.httpForwarder.ServeTCP(r.GetConn(conn, hello.peeked)) + r.httpForwarder.ServeTCP(r.GetConn(conn, peeked.String())) default: conn.Close() } @@ -179,7 +174,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // Handling ACME-TLS/1 challenges. if !r.acmeTLSPassthrough && slices.Contains(hello.protos, tlsalpn01.ACMETLS1Protocol) { - r.acmeTLSALPNHandler().ServeTCP(r.GetConn(conn, hello.peeked)) + r.acmeTLSALPNHandler().ServeTCP(r.GetConn(conn, peeked.String())) return } @@ -193,14 +188,14 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // In order not to depart from the behavior in 2.6, // we only allow an HTTPS router to take precedence over a TCP-TLS router if it is _not_ an HostSNI(*) router // (so basically any router that has a specific HostSNI based rule). - handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked)) + handlerHTTPS.ServeTCP(r.GetConn(conn, peeked.String())) return } // Contains also TCP TLS passthrough routes. handlerTCPTLS, catchAllTCPTLS := r.muxerTCPTLS.Match(connData) if handlerTCPTLS != nil && !catchAllTCPTLS { - handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked)) + handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked.String())) return } @@ -208,19 +203,19 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) { // We end up here for e.g. an HTTPS router that only has a PathPrefix rule, // which under the scenes is counted as an HostSNI(*) rule. if handlerHTTPS != nil { - handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked)) + handlerHTTPS.ServeTCP(r.GetConn(conn, peeked.String())) return } // Fallback on TCP TLS catchAll. if handlerTCPTLS != nil { - handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked)) + handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked.String())) return } // To handle 404s for HTTPS. if r.httpsForwarder != nil { - r.httpsForwarder.ServeTCP(r.GetConn(conn, hello.peeked)) + r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked.String())) return } @@ -375,7 +370,6 @@ type clientHello struct { serverName string // SNI server name protos []string // ALPN protocols list isTLS bool // whether we are a TLS handshake - peeked string // the bytes peeked from the hello while getting the info } // clientHelloInfo returns various data from the clientHello handshake, @@ -396,74 +390,46 @@ func clientHelloInfo(br *bufio.Reader) (*clientHello, error) { if hdr[0] == recordTypeSSLv2 { // we consider SSLv2 as TLS, and it will be refused by real TLS handshake. return &clientHello{ - isTLS: true, - peeked: getPeeked(br), + isTLS: true, }, nil } - return &clientHello{ - peeked: getPeeked(br), - }, nil // Not TLS. + return &clientHello{}, nil // Not TLS. } - const recordHeaderLen = 5 - hdr, err = br.Peek(recordHeaderLen) - if err != nil { - return nil, fmt.Errorf("peeking client hello headers: %w", err) - } - - recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] - - if recLen > maxTLSRecordLen { - return nil, fmt.Errorf("peeking client hello bytes, oversized record: %d", recLen) - } - - if recordHeaderLen+recLen > defaultBufSize { - br = bufio.NewReaderSize(br, recordHeaderLen+recLen) - } - - helloBytes, err := br.Peek(recordHeaderLen + recLen) - if err != nil { - return nil, fmt.Errorf("peeking client hello bytes: %w", err) - } - - sni := "" - var protos []string - server := tls.Server(helloSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + var ( + sni string + protos []string + ) + server := tls.Server(readOnlyConn{r: br}, &tls.Config{ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { sni = hello.ServerName protos = hello.SupportedProtos - return nil, nil + // This error prevents unnecessary additional steps in the TLS ClientHello message processing. + return nil, errClientHelloRead }, }) - _ = server.Handshake() + + if handshakeErr := server.Handshake(); !errors.Is(handshakeErr, errClientHelloRead) { + return nil, fmt.Errorf("reading client hello: %w", handshakeErr) + } return &clientHello{ serverName: sni, isTLS: true, - peeked: getPeeked(br), protos: protos, }, nil } -func getPeeked(br *bufio.Reader) string { - peeked, err := br.Peek(br.Buffered()) - if err != nil { - log.Error().Err(err).Msg("Error while peeking bytes") - return "" - } - return string(peeked) -} - -// helloSniffConn is a net.Conn that reads from r, fails on Writes, +// readOnlyConn is a net.Conn that reads from r, fails on Writes, // and crashes otherwise. -type helloSniffConn struct { +type readOnlyConn struct { net.Conn // nil; crash on any unexpected use r io.Reader } // Read reads from the underlying reader. -func (c helloSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) } // Write crashes all the time. -func (helloSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } +func (readOnlyConn) Write(_ []byte) (int, error) { return 0, io.EOF } diff --git a/pkg/server/router/tcp/router_test.go b/pkg/server/router/tcp/router_test.go index 0634c0f60a..7292de8a69 100644 --- a/pkg/server/router/tcp/router_test.go +++ b/pkg/server/router/tcp/router_test.go @@ -1151,12 +1151,34 @@ func Test_clientHelloInfo_oversizedRecordLength(t *testing.T) { } } -// Test_clientHelloInfo_validRecordLength verifies that clientHelloInfo -// still works correctly with legitimate TLS record sizes. -func Test_clientHelloInfo_validRecordLength(t *testing.T) { +// Test_clientHelloInfo_tlsRecordFragmentation documents a known limitation: +// clientHelloInfo only reads a single TLS record. When a ClientHello handshake +// message is split across multiple TLS records (RFC 5246 ยง6.2.1), the SNI cannot +// be extracted, leaving serverName empty and allowing SNI-based routing to be bypassed. +func Test_clientHelloInfo_tlsRecordFragmentation(t *testing.T) { + serverName := "foo.example.com" + record := buildClientHelloRecord(t, serverName) + + const hdrLen = 5 + payload := record[hdrLen:] + + ver1, ver2 := record[1], record[2] + + var recordsData bytes.Buffer + for _, part := range [][]byte{payload[:len(serverName)/2], payload[len(serverName)/2:]} { + recordsData.WriteByte(0x16) + recordsData.WriteByte(ver1) + recordsData.WriteByte(ver2) + recordsData.WriteByte(byte(len(part) >> 8)) + recordsData.WriteByte(byte(len(part))) + recordsData.Write(part) + } + serverConn, clientConn := net.Pipe() - defer serverConn.Close() - defer clientConn.Close() + t.Cleanup(func() { + _ = serverConn.Close() + _ = clientConn.Close() + }) type result struct { hello *clientHello @@ -1170,30 +1192,51 @@ func Test_clientHelloInfo_validRecordLength(t *testing.T) { resultCh <- result{hello, err} }() - // Build a TLS record header with a small (valid) record length. - recLen := 100 - hdr := []byte{ - 0x16, // Content Type: Handshake - 0x03, 0x03, // Version: TLS 1.2 - byte(recLen >> 8), // Length high byte - byte(recLen & 0xFF), // Length low byte - } - payload := make([]byte, recLen) - - _, err := clientConn.Write(append(hdr, payload...)) + _, err := clientConn.Write(recordsData.Bytes()) require.NoError(t, err) - clientConn.Close() + _ = clientConn.Close() select { case r := <-resultCh: require.NoError(t, r.err) require.NotNil(t, r.hello) assert.True(t, r.hello.isTLS) + assert.Equal(t, serverName, r.hello.serverName) case <-time.After(5 * time.Second): - t.Fatal("clientHelloInfo blocked on valid TLS record") + t.Fatal("clientHelloInfo blocked") } } +// buildClientHelloRecord captures a real TLS ClientHello record from Go's TLS stack +// for the given serverName. +// It returns the raw record bytes and the byte offset of the SNI value within those bytes. +func buildClientHelloRecord(t *testing.T, serverName string) []byte { + t.Helper() + + serverConn, clientConn := net.Pipe() + + recordCh := make(chan []byte, 1) + go func() { + buf := make([]byte, 65536) + n, _ := serverConn.Read(buf) + _ = serverConn.Close() + recordCh <- buf[:n] + }() + + go func() { + tlsConn := tls.Client(clientConn, &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + _ = clientConn.Close() + }() + + record := <-recordCh + + return record +} + func TestPostgres(t *testing.T) { router, err := NewRouter() require.NoError(t, err)