Fix snicheck with keepalive

This commit is contained in:
Julien Salleyron 2026-06-05 14:36:05 +02:00 committed by GitHub
parent 5404f6fb25
commit b6bb80f8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 111 additions and 50 deletions

View file

@ -327,7 +327,9 @@ linters:
text: 'SA1008: keys in http.Header are canonicalized, "x-user" is not canonical; fix the constant or use http.CanonicalHeaderKey'
- path: pkg/middlewares/auth/digest_auth_test.go
text: 'SA1008: keys in http.Header are canonicalized, "x-user" is not canonical; fix the constant or use http.CanonicalHeaderKey'
- path: pkg/server/conncontext.go
linters:
- fatcontext
paths:
- pkg/provider/kubernetes/crd/generated/

29
pkg/server/conncontext.go Normal file
View file

@ -0,0 +1,29 @@
package server
import (
"context"
"net"
)
type connContextFunc func(context.Context, net.Conn) context.Context
type multipleConnContext struct {
fns []connContextFunc
}
func (m *multipleConnContext) AddConnContextFunc(fn connContextFunc) {
m.fns = append(m.fns, fn)
}
func (m *multipleConnContext) Build() connContextFunc {
if len(m.fns) == 0 {
return nil
}
return func(ctx context.Context, c net.Conn) context.Context {
for _, contextFunc := range m.fns {
ctx = contextFunc(ctx, c)
}
return ctx
}
}

View file

@ -0,0 +1,26 @@
package server
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/require"
)
type keyTest string
func TestConnContext(t *testing.T) {
var connContext multipleConnContext
connContext.AddConnContextFunc(func(ctx context.Context, _ net.Conn) context.Context {
return context.WithValue(ctx, keyTest("test"), "test")
})
connContext.AddConnContextFunc(func(ctx context.Context, _ net.Conn) context.Context {
return context.WithValue(ctx, keyTest("test2"), "test2")
})
ctx := connContext.Build()(context.Background(), nil)
require.Equal(t, "test", ctx.Value(keyTest("test")))
require.Equal(t, "test2", ctx.Value(keyTest("test2")))
}

View file

@ -114,6 +114,13 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
ctxRouter := log.With(provider.AddInContext(ctx, routerHTTPName), log.Str(log.RouterName, routerHTTPName))
logger := log.FromContext(ctxRouter)
// Even if the TLS options mismatch between the configured and the resolved one is handled in the aggregator
// we also have to handle it here to be able to mark the router in error.
tlsOptionsName := traefiktls.DefaultTLSConfigName
if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName {
tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options)
}
domains, err := httpmuxer.ParseDomains(routerHTTPConfig.Rule)
if err != nil {
routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err)
@ -153,17 +160,14 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
// # When a request for "/foo" comes, even though it won't be routed by httpRouter2,
// # if its SNI is set to foo.com, myTLSOptions will be used for the TLS connection.
// # Otherwise, it will fallback to the default TLS config.
logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule)
if tlsOptionsName != traefiktls.DefaultTLSConfigName {
logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule)
routerHTTPConfig.AddError(fmt.Errorf("no domain found in rule %v, the TLS option %s cannot be applied", routerHTTPConfig.Rule, tlsOptionsName), false)
}
}
// Even if the TLS options mismatch between the configured and the resolved one is handled in the aggregator
// we also have to handle it here to be able to mark the router in error.
tlsOptionsName := traefiktls.DefaultTLSConfigName
if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName {
tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options)
}
if routerHTTPConfig.TLS.ResolvedOptions != tlsOptionsName {
if len(domains) > 0 && routerHTTPConfig.TLS.ResolvedOptions != tlsOptionsName {
logger.Warn("Found different TLS options for routers on the same host, so using the default TLS options instead.")
routerHTTPConfig.AddError(errors.New("found different TLS options for routers on the same host, so using the default TLS options instead"), false)
}

View file

@ -601,6 +601,44 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
handler = denyFragment(handler)
var connContext multipleConnContext
connContext.AddConnContextFunc(func(ctx context.Context, c net.Conn) context.Context {
// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
ctx = service.AddTransportOnContext(ctx)
if tlsConn, ok := c.(*tls.Conn); ok {
if tlsConnWithOptionsName, ok := tlsConn.NetConn().(tcp.TLSConn); ok {
return tcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName)
}
}
return ctx
})
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
connContext.AddConnContextFunc(func(ctx context.Context, c net.Conn) context.Context {
cState := &connState{Start: time.Now()}
if debugConnection {
clientConnectionStatesMu.Lock()
clientConnectionStates[getConnKey(c)] = cState
clientConnectionStatesMu.Unlock()
}
return context.WithValue(ctx, connStateKey, cState)
})
}
var connState func(c net.Conn, state http.ConnState)
if debugConnection {
connState = func(c net.Conn, state http.ConnState) {
clientConnectionStatesMu.Lock()
if clientConnectionStates[getConnKey(c)] != nil {
clientConnectionStates[getConnKey(c)].State = state.String()
}
clientConnectionStatesMu.Unlock()
}
}
serverHTTP := &http.Server{
Protocols: &protocols,
Handler: handler,
@ -611,46 +649,8 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
HTTP2: &http.HTTP2Config{
MaxConcurrentStreams: int(configuration.HTTP2.MaxConcurrentStreams),
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
if tlsConn, ok := c.(*tls.Conn); ok {
if tlsConnWithOptionsName, ok := tlsConn.NetConn().(tcp.TLSConn); ok {
return tcp.AddTLSOptionsNameInContext(ctx, tlsConnWithOptionsName.TLSOptionsName)
}
}
return ctx
},
}
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
cState := &connState{Start: time.Now()}
if debugConnection {
clientConnectionStatesMu.Lock()
clientConnectionStates[getConnKey(c)] = cState
clientConnectionStatesMu.Unlock()
}
return context.WithValue(ctx, connStateKey, cState)
}
if debugConnection {
serverHTTP.ConnState = func(c net.Conn, state http.ConnState) {
clientConnectionStatesMu.Lock()
if clientConnectionStates[getConnKey(c)] != nil {
clientConnectionStates[getConnKey(c)].State = state.String()
}
clientConnectionStatesMu.Unlock()
}
}
}
prevConnContext := serverHTTP.ConnContext
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
ctx = service.AddTransportOnContext(ctx)
if prevConnContext != nil {
return prevConnContext(ctx, c)
}
return ctx
ConnContext: connContext.Build(),
ConnState: connState,
}
listener := newHTTPForwarder(ln)