From b6bb80f8ff7e564f47acebade9f78e01f689264d Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Fri, 5 Jun 2026 14:36:05 +0200 Subject: [PATCH] Fix snicheck with keepalive --- .golangci.yml | 4 +- pkg/server/conncontext.go | 29 +++++++++++ pkg/server/conncontext_test.go | 26 ++++++++++ pkg/server/router/tcp/manager.go | 22 ++++---- pkg/server/server_entrypoint_tcp.go | 80 ++++++++++++++--------------- 5 files changed, 111 insertions(+), 50 deletions(-) create mode 100644 pkg/server/conncontext.go create mode 100644 pkg/server/conncontext_test.go diff --git a/.golangci.yml b/.golangci.yml index 02749c267d..feefe53c34 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -327,7 +327,9 @@ linters: text: 'SA1008: keys in http.Header are canonicalized, "x-user" is not canonical; fix the constant or use http.CanonicalHeaderKey' - path: pkg/middlewares/auth/digest_auth_test.go text: 'SA1008: keys in http.Header are canonicalized, "x-user" is not canonical; fix the constant or use http.CanonicalHeaderKey' - + - path: pkg/server/conncontext.go + linters: + - fatcontext paths: - pkg/provider/kubernetes/crd/generated/ diff --git a/pkg/server/conncontext.go b/pkg/server/conncontext.go new file mode 100644 index 0000000000..a93290c7a2 --- /dev/null +++ b/pkg/server/conncontext.go @@ -0,0 +1,29 @@ +package server + +import ( + "context" + "net" +) + +type connContextFunc func(context.Context, net.Conn) context.Context + +type multipleConnContext struct { + fns []connContextFunc +} + +func (m *multipleConnContext) AddConnContextFunc(fn connContextFunc) { + m.fns = append(m.fns, fn) +} + +func (m *multipleConnContext) Build() connContextFunc { + if len(m.fns) == 0 { + return nil + } + + return func(ctx context.Context, c net.Conn) context.Context { + for _, contextFunc := range m.fns { + ctx = contextFunc(ctx, c) + } + return ctx + } +} diff --git a/pkg/server/conncontext_test.go b/pkg/server/conncontext_test.go new file mode 100644 index 0000000000..6160ea358e --- /dev/null +++ b/pkg/server/conncontext_test.go @@ -0,0 +1,26 @@ +package server + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +type keyTest string + +func TestConnContext(t *testing.T) { + var connContext multipleConnContext + connContext.AddConnContextFunc(func(ctx context.Context, _ net.Conn) context.Context { + return context.WithValue(ctx, keyTest("test"), "test") + }) + connContext.AddConnContextFunc(func(ctx context.Context, _ net.Conn) context.Context { + return context.WithValue(ctx, keyTest("test2"), "test2") + }) + + ctx := connContext.Build()(context.Background(), nil) + + require.Equal(t, "test", ctx.Value(keyTest("test"))) + require.Equal(t, "test2", ctx.Value(keyTest("test2"))) +} diff --git a/pkg/server/router/tcp/manager.go b/pkg/server/router/tcp/manager.go index 0a17dfb272..de9e9e2762 100644 --- a/pkg/server/router/tcp/manager.go +++ b/pkg/server/router/tcp/manager.go @@ -114,6 +114,13 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string ctxRouter := log.With(provider.AddInContext(ctx, routerHTTPName), log.Str(log.RouterName, routerHTTPName)) logger := log.FromContext(ctxRouter) + // Even if the TLS options mismatch between the configured and the resolved one is handled in the aggregator + // we also have to handle it here to be able to mark the router in error. + tlsOptionsName := traefiktls.DefaultTLSConfigName + if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName { + tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options) + } + domains, err := httpmuxer.ParseDomains(routerHTTPConfig.Rule) if err != nil { routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err) @@ -153,17 +160,14 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string // # When a request for "/foo" comes, even though it won't be routed by httpRouter2, // # if its SNI is set to foo.com, myTLSOptions will be used for the TLS connection. // # Otherwise, it will fallback to the default TLS config. - logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule) + if tlsOptionsName != traefiktls.DefaultTLSConfigName { + logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule) + routerHTTPConfig.AddError(fmt.Errorf("no domain found in rule %v, the TLS option %s cannot be applied", routerHTTPConfig.Rule, tlsOptionsName), false) + } } - // Even if the TLS options mismatch between the configured and the resolved one is handled in the aggregator - // we also have to handle it here to be able to mark the router in error. - tlsOptionsName := traefiktls.DefaultTLSConfigName - if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName { - tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options) - } - - if routerHTTPConfig.TLS.ResolvedOptions != tlsOptionsName { + if len(domains) > 0 && routerHTTPConfig.TLS.ResolvedOptions != tlsOptionsName { + logger.Warn("Found different TLS options for routers on the same host, so using the default TLS options instead.") routerHTTPConfig.AddError(errors.New("found different TLS options for routers on the same host, so using the default TLS options instead"), false) } diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index 39dd3b42fe..252ec0514a 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -601,6 +601,44 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati handler = denyFragment(handler) + var connContext multipleConnContext + connContext.AddConnContextFunc(func(ctx context.Context, c net.Conn) context.Context { + // This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM. + ctx = service.AddTransportOnContext(ctx) + + if tlsConn, ok := c.(*tls.Conn); ok { + if tlsConnWithOptionsName, ok := tlsConn.NetConn().(tcp.TLSConn); ok { + return tcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName) + } + } + + return ctx + }) + + if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) { + connContext.AddConnContextFunc(func(ctx context.Context, c net.Conn) context.Context { + cState := &connState{Start: time.Now()} + if debugConnection { + clientConnectionStatesMu.Lock() + clientConnectionStates[getConnKey(c)] = cState + clientConnectionStatesMu.Unlock() + } + + return context.WithValue(ctx, connStateKey, cState) + }) + } + + var connState func(c net.Conn, state http.ConnState) + if debugConnection { + connState = func(c net.Conn, state http.ConnState) { + clientConnectionStatesMu.Lock() + if clientConnectionStates[getConnKey(c)] != nil { + clientConnectionStates[getConnKey(c)].State = state.String() + } + clientConnectionStatesMu.Unlock() + } + } + serverHTTP := &http.Server{ Protocols: &protocols, Handler: handler, @@ -611,46 +649,8 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati HTTP2: &http.HTTP2Config{ MaxConcurrentStreams: int(configuration.HTTP2.MaxConcurrentStreams), }, - ConnContext: func(ctx context.Context, c net.Conn) context.Context { - if tlsConn, ok := c.(*tls.Conn); ok { - if tlsConnWithOptionsName, ok := tlsConn.NetConn().(tcp.TLSConn); ok { - return tcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName) - } - } - - return ctx - }, - } - if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) { - serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - cState := &connState{Start: time.Now()} - if debugConnection { - clientConnectionStatesMu.Lock() - clientConnectionStates[getConnKey(c)] = cState - clientConnectionStatesMu.Unlock() - } - return context.WithValue(ctx, connStateKey, cState) - } - - if debugConnection { - serverHTTP.ConnState = func(c net.Conn, state http.ConnState) { - clientConnectionStatesMu.Lock() - if clientConnectionStates[getConnKey(c)] != nil { - clientConnectionStates[getConnKey(c)].State = state.String() - } - clientConnectionStatesMu.Unlock() - } - } - } - - prevConnContext := serverHTTP.ConnContext - serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - // This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM. - ctx = service.AddTransportOnContext(ctx) - if prevConnContext != nil { - return prevConnContext(ctx, c) - } - return ctx + ConnContext: connContext.Build(), + ConnState: connState, } listener := newHTTPForwarder(ln)