diff --git a/docs/content/reference/routing-configuration/http/load-balancing/service.md b/docs/content/reference/routing-configuration/http/load-balancing/service.md index 7039ac3db8..131fc2691e 100644 --- a/docs/content/reference/routing-configuration/http/load-balancing/service.md +++ b/docs/content/reference/routing-configuration/http/load-balancing/service.md @@ -760,7 +760,7 @@ The `mirroring` service type mirrors requests sent to a service to other service !!! info "Supported Providers" This service type can be defined currently with the [File](../../../install-configuration/providers/others/file.md) provider or [IngressRoute](../../../routing-configuration/kubernetes/crd/http/ingressroute.md). - + ```yaml tab="Structured (YAML)" ## Routing configuration http: @@ -887,9 +887,13 @@ http: url = "http://private-ip-server-2/" ``` -### Failover +### Failover -The `failover` service type forwards all requests to a fallback service when the main service becomes unreachable. +The `failover` service type forwards requests to a fallback service when the main service is unavailable. +Failover can be triggered in two ways: + +- **Health check-based**: When the main service becomes unreachable based on [health checks](#health-check). +- **Status code-based**: When the main service responds with specific HTTP status codes defined in the [errors](#errors) configuration. !!! info "Relation to HealthCheck" The failover service relies on the HealthCheck system to get notified when its main service becomes unreachable, which means HealthCheck needs to be enabled and functional on the main service. However, HealthCheck does not need to be enabled on the failover service itself for it to be functional. It is only required in order to propagate upwards the information when the failover itself becomes down (i.e. both its main and its fallback are down too). @@ -940,15 +944,15 @@ http: ## Routing configuration [http.services] [http.services.app] - [http.services.app.failover.healthCheck] [http.services.app.failover] service = "main" fallback = "backup" + [http.services.app.failover.healthCheck] [http.services.main] [http.services.main.loadBalancer] [http.services.main.loadBalancer.healthCheck] - path = "/health" + path = "/status" interval = "10s" timeout = "3s" [[http.services.main.loadBalancer.servers]] @@ -957,9 +961,163 @@ http: [http.services.backup] [http.services.backup.loadBalancer] [http.services.backup.loadBalancer.healthCheck] - path = "/health" + path = "/status" interval = "10s" timeout = "3s" [[http.services.backup.loadBalancer.servers]] url = "http://private-ip-server-2/" ``` + +#### Errors + +The `errors` option enables status code-based failover. +When the main service responds with an HTTP status code matching one of the configured ranges, Traefik automatically retries the request on the fallback service. + +To support request replay, the request body is buffered up to `maxRequestBodyBytes`. +Requests with bodies larger than this limit receive a `413 Request Entity Too Large` response. + +Below is a list of options available for the `errors` option and an example of how to configure it for a failover service: + +| Field | Description | Default | +|-----------------------|-------------------------------------------------------------------------------------------------------------------|---------| +| `status` | List of HTTP status code ranges that trigger failover. Supports single codes (`"500"`) and ranges (`"500-504"`). | None | +| `maxRequestBodyBytes` | Maximum request body size (in bytes) to buffer for replay to the fallback service. Set to `-1` for no limit. | `-1` | + +```yaml tab="Structured (YAML)" +## Routing configuration +http: + services: + app: + failover: + service: main + fallback: backup + errors: + status: + - "500-504" + maxRequestBodyBytes: 1048576 + + main: + loadBalancer: + servers: + - url: "http://private-ip-server-1/" + + backup: + loadBalancer: + servers: + - url: "http://private-ip-server-2/" +``` + +```toml tab="Structured (TOML)" +## Routing configuration +[http.services] + [http.services.app] + [http.services.app.failover] + service = "main" + fallback = "backup" + [http.services.app.failover.errors] + status = ["500-504"] + maxRequestBodyBytes = 1048576 + + [http.services.main] + [http.services.main.loadBalancer] + [[http.services.main.loadBalancer.servers]] + url = "http://private-ip-server-1/" + + [http.services.backup] + [http.services.backup.loadBalancer] + [[http.services.backup.loadBalancer.servers]] + url = "http://private-ip-server-2/" +``` + +#### Chaining Failover Services + +Failover services can be chained together for multi-level redundancy. +In the following example, if the primary service fails, traffic goes to the secondary service. +If both primary and secondary fail, traffic goes to the tertiary service. + +```yaml tab="Structured (YAML)" +## Routing configuration +http: + services: + app: + failover: + healthCheck: {} + service: primary-failover + fallback: tertiary + + primary-failover: + failover: + healthCheck: {} + service: primary + fallback: secondary + + primary: + loadBalancer: + healthCheck: + path: /health + interval: 10s + timeout: 3s + servers: + - url: "http://primary-server/" + + secondary: + loadBalancer: + healthCheck: + path: /health + interval: 10s + timeout: 3s + servers: + - url: "http://secondary-server/" + + tertiary: + loadBalancer: + healthCheck: + path: /health + interval: 10s + timeout: 3s + servers: + - url: "http://tertiary-server/" +``` + +```toml tab="Structured (TOML)" +## Routing configuration +[http.services] + [http.services.app] + [http.services.app.failover] + service = "primary-failover" + fallback = "tertiary" + [http.services.app.failover.healthCheck] + + [http.services.primary-failover] + [http.services.primary-failover.failover] + service = "primary" + fallback = "secondary" + [http.services.primary-failover.failover.healthCheck] + + [http.services.primary] + [http.services.primary.loadBalancer] + [http.services.primary.loadBalancer.healthCheck] + path = "/health" + interval = "10s" + timeout = "3s" + [[http.services.primary.loadBalancer.servers]] + url = "http://primary-server/" + + [http.services.secondary] + [http.services.secondary.loadBalancer] + [http.services.secondary.loadBalancer.healthCheck] + path = "/health" + interval = "10s" + timeout = "3s" + [[http.services.secondary.loadBalancer.servers]] + url = "http://secondary-server/" + + [http.services.tertiary] + [http.services.tertiary.loadBalancer] + [http.services.tertiary.loadBalancer.healthCheck] + path = "/health" + interval = "10s" + timeout = "3s" + [[http.services.tertiary.loadBalancer.servers]] + url = "http://tertiary-server/" +``` diff --git a/docs/content/reference/routing-configuration/other-providers/file.toml b/docs/content/reference/routing-configuration/other-providers/file.toml index 2f94a07ddb..994b342dc3 100644 --- a/docs/content/reference/routing-configuration/other-providers/file.toml +++ b/docs/content/reference/routing-configuration/other-providers/file.toml @@ -56,6 +56,9 @@ service = "foobar" fallback = "foobar" [http.services.Service01.failover.healthCheck] + [http.services.Service01.failover.errors] + maxRequestBodyBytes = 42 + status = ["foobar", "foobar"] [http.services.Service02] [http.services.Service02.highestRandomWeight] diff --git a/docs/content/reference/routing-configuration/other-providers/file.yaml b/docs/content/reference/routing-configuration/other-providers/file.yaml index 50200a4fb1..0166d98351 100644 --- a/docs/content/reference/routing-configuration/other-providers/file.yaml +++ b/docs/content/reference/routing-configuration/other-providers/file.yaml @@ -70,6 +70,11 @@ http: service: foobar fallback: foobar healthCheck: {} + errors: + maxRequestBodyBytes: 42 + status: + - foobar + - foobar Service02: highestRandomWeight: services: diff --git a/integration/fixtures/failover.toml b/integration/fixtures/failover.toml new file mode 100644 index 0000000000..ae13a938ca --- /dev/null +++ b/integration/fixtures/failover.toml @@ -0,0 +1,53 @@ +[global] + checkNewVersion = false + sendAnonymousUsage = false + +[api] + insecure = true + +[log] + level = "DEBUG" + noColor = true + +[entryPoints] + [entryPoints.web] + address = ":8000" + +[providers.file] + filename = "{{ .SelfFilename }}" + +## dynamic configuration ## + +[http.routers] + [http.routers.router1] + entrypoints = ["web"] + service = "failover-service" + rule = "Path(`/whoami`)" + +[http.services] + # Failover service with health check + [http.services.failover-service] + [http.services.failover-service.failover] + service = "main-service" + fallback = "fallback-service" + [http.services.failover-service.failover.healthCheck] + + # Main service with health check enabled + [http.services.main-service] + [http.services.main-service.loadBalancer] + [http.services.main-service.loadBalancer.healthCheck] + path = "/health" + interval = "1s" + timeout = "0.9s" + [[http.services.main-service.loadBalancer.servers]] + url = "http://{{ .MainServer }}:80" + + # Fallback service with health check enabled + [http.services.fallback-service] + [http.services.fallback-service.loadBalancer] + [http.services.fallback-service.loadBalancer.healthCheck] + path = "/health" + interval = "1s" + timeout = "0.9s" + [[http.services.fallback-service.loadBalancer.servers]] + url = "http://{{ .FallbackServer }}:80" \ No newline at end of file diff --git a/integration/fixtures/failover_statuscode.toml b/integration/fixtures/failover_statuscode.toml new file mode 100644 index 0000000000..11d04336f4 --- /dev/null +++ b/integration/fixtures/failover_statuscode.toml @@ -0,0 +1,47 @@ +[global] + checkNewVersion = false + sendAnonymousUsage = false + +[api] + insecure = true + +[log] + level = "DEBUG" + noColor = true + +[entryPoints] + [entryPoints.web] + address = ":8000" + +[providers.file] + filename = "{{ .SelfFilename }}" + +## dynamic configuration ## + +[http.routers] + [http.routers.router1] + entrypoints = ["web"] + service = "failover-service" + rule = "PathPrefix(`/`)" + +[http.services] + # Failover service with error-based failover (status codes) + [http.services.failover-service] + [http.services.failover-service.failover] + service = "main-service" + fallback = "fallback-service" + [http.services.failover-service.failover.errors] + status = ["500-504"] + maxRequestBodyBytes = 1048576 # 1MB + + # Main service (no health check - failover based on status codes only) + [http.services.main-service] + [http.services.main-service.loadBalancer] + [[http.services.main-service.loadBalancer.servers]] + url = "{{ .MainServer }}" + + # Fallback service + [http.services.fallback-service] + [http.services.fallback-service.loadBalancer] + [[http.services.fallback-service.loadBalancer.servers]] + url = "{{ .FallbackServer }}" \ No newline at end of file diff --git a/integration/simple_test.go b/integration/simple_test.go index 1d1e1a8185..312eb5943e 100644 --- a/integration/simple_test.go +++ b/integration/simple_test.go @@ -2365,6 +2365,173 @@ func (s *SimpleSuite) TestEncodedCharactersDifferentEntryPoints() { } } +func (s *SimpleSuite) TestFailoverService() { + s.createComposeProject("base") + + s.composeUp() + defer s.composeDown() + + whoami1IP := s.getComposeServiceIP("whoami1") + whoami2IP := s.getComposeServiceIP("whoami2") + + file := s.adaptFile("fixtures/failover.toml", struct { + MainServer string + FallbackServer string + }{ + MainServer: whoami1IP, + FallbackServer: whoami2IP, + }) + + s.traefikCmd(withConfigFile(file)) + + // Wait for Traefik to be ready + err := try.GetRequest("http://127.0.0.1:8080/api/http/services", 2*time.Second, try.BodyContains("failover-service")) + require.NoError(s.T(), err) + + // Test 1: When main service is healthy, traffic should go to main + var primaryCount, fallbackCount int + for range 5 { + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil) + require.NoError(s.T(), err) + + response, err := http.DefaultClient.Do(req) + require.NoError(s.T(), err) + assert.Equal(s.T(), http.StatusOK, response.StatusCode) + + body, err := io.ReadAll(response.Body) + require.NoError(s.T(), err) + + if strings.Contains(string(body), whoami1IP) { + primaryCount++ + } + if strings.Contains(string(body), whoami2IP) { + fallbackCount++ + } + } + + // All requests should go to the main service (whoami1) + assert.Equal(s.T(), 5, primaryCount, "Expected all requests to go to main service") + assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service") + + // Test 2: Stop the main service to trigger failover via health check + s.composeStop("whoami1") + + // Wait for health check to detect the main service is down + time.Sleep(3 * time.Second) + + // Now all traffic should go to the fallback service (whoami2) + primaryCount, fallbackCount = 0, 0 + for range 5 { + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil) + require.NoError(s.T(), err) + + response, err := http.DefaultClient.Do(req) + require.NoError(s.T(), err) + assert.Equal(s.T(), http.StatusOK, response.StatusCode) + + body, err := io.ReadAll(response.Body) + require.NoError(s.T(), err) + + if strings.Contains(string(body), whoami1IP) { + primaryCount++ + } + if strings.Contains(string(body), whoami2IP) { + fallbackCount++ + } + } + + assert.Equal(s.T(), 0, primaryCount, "Expected no requests to go to main service when down") + assert.Equal(s.T(), 5, fallbackCount, "Expected all requests to go to fallback service") + + // Test 3: Restart main service and verify traffic returns to main + s.composeUp("whoami1") + + // Wait for health check to detect the main service is back up + time.Sleep(3 * time.Second) + + primaryCount, fallbackCount = 0, 0 + for range 5 { + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil) + require.NoError(s.T(), err) + + response, err := http.DefaultClient.Do(req) + require.NoError(s.T(), err) + assert.Equal(s.T(), http.StatusOK, response.StatusCode) + + body, err := io.ReadAll(response.Body) + require.NoError(s.T(), err) + + if strings.Contains(string(body), whoami1IP) { + primaryCount++ + } + if strings.Contains(string(body), whoami2IP) { + fallbackCount++ + } + } + + // Traffic should return to the main service + assert.Equal(s.T(), 5, primaryCount, "Expected all requests to return to main service when back up") + assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service") + + // Test 4: Stop both services and verify we get 503 + s.composeStop("whoami1") + s.composeStop("whoami2") + + // Wait for health checks to detect both services are down + time.Sleep(3 * time.Second) + + // Request should return 503 Service Unavailable when both services are down + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil) + require.NoError(s.T(), err) + + response, err := http.DefaultClient.Do(req) + require.NoError(s.T(), err) + assert.Equal(s.T(), http.StatusServiceUnavailable, response.StatusCode) +} + +func (s *SimpleSuite) TestFailoverServiceWithStatusCode() { + var mainCallCount, fallbackCallCount atomic.Int32 + + // Create a test server that returns 503 to trigger error-based failover + mainServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + mainCallCount.Add(1) + rw.WriteHeader(http.StatusServiceUnavailable) + _, _ = rw.Write([]byte("main service unavailable")) + })) + defer mainServer.Close() + + // Create a fallback server that returns 200 + fallbackServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fallbackCallCount.Add(1) + rw.WriteHeader(http.StatusOK) + + _, _ = rw.Write([]byte("fallback service")) + })) + defer fallbackServer.Close() + + file := s.adaptFile("fixtures/failover_statuscode.toml", struct { + MainServer string + FallbackServer string + }{ + MainServer: mainServer.URL, + FallbackServer: fallbackServer.URL, + }) + + s.traefikCmd(withConfigFile(file)) + + // Wait for Traefik to be ready and verify the configuration is loaded + err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("PathPrefix")) + require.NoError(s.T(), err) + + // Make a request - should failover to fallback because main returns 503 + err = try.GetRequest("http://127.0.0.1:8000/", 5*time.Second, try.BodyContains("fallback service")) + require.NoError(s.T(), err) + + // Main was called but returned 503, triggering failover to fallback + assert.GreaterOrEqual(s.T(), mainCallCount.Load(), int32(1), "Main service should have been called at least once") + assert.GreaterOrEqual(s.T(), fallbackCallCount.Load(), int32(1), "Fallback service should have been called at least once") +} + func (s *SimpleSuite) TestServiceMiddleware() { s.createComposeProject("base") diff --git a/pkg/config/dynamic/http_config.go b/pkg/config/dynamic/http_config.go index a58143daa3..30db7c3cce 100644 --- a/pkg/config/dynamic/http_config.go +++ b/pkg/config/dynamic/http_config.go @@ -10,6 +10,7 @@ import ( traefiktls "github.com/traefik/traefik/v3/pkg/tls" "github.com/traefik/traefik/v3/pkg/types" "google.golang.org/grpc/codes" + "k8s.io/utils/ptr" ) const ( @@ -28,6 +29,8 @@ const ( MirroringDefaultMirrorBody = true // MirroringDefaultMaxBodySize is the Mirroring.MaxBodySize option default value. MirroringDefaultMaxBodySize int64 = -1 + // FailoverErrorsDefaultMaxRequestBodyBytes is the Failover.Errors.MaxBodySize option default value. + FailoverErrorsDefaultMaxRequestBodyBytes int64 = -1 ) // +k8s:deepcopy-gen=true @@ -195,9 +198,23 @@ func (m *Mirroring) SetDefaults() { // Failover holds the Failover configuration. type Failover struct { - Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"` - Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"` - HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"` + Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"` + Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"` + HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"` + Errors *FailoverError `json:"errors,omitempty" toml:"errors,omitempty" yaml:"errors,omitempty" export:"true"` +} + +// +k8s:deepcopy-gen=true + +// FailoverError holds errors configuration. +type FailoverError struct { + MaxRequestBodyBytes *int64 `json:"maxRequestBodyBytes,omitempty" toml:"maxRequestBodyBytes,omitempty" yaml:"maxRequestBodyBytes,omitempty" export:"true"` + Status []string `json:"status,omitempty" toml:"status,omitempty" yaml:"status,omitempty" export:"true"` +} + +// SetDefaults Default values for a WRRService. +func (m *FailoverError) SetDefaults() { + m.MaxRequestBodyBytes = ptr.To(FailoverErrorsDefaultMaxRequestBodyBytes) } // +k8s:deepcopy-gen=true diff --git a/pkg/config/dynamic/zz_generated.deepcopy.go b/pkg/config/dynamic/zz_generated.deepcopy.go index 9651754533..e9375771af 100644 --- a/pkg/config/dynamic/zz_generated.deepcopy.go +++ b/pkg/config/dynamic/zz_generated.deepcopy.go @@ -358,6 +358,11 @@ func (in *Failover) DeepCopyInto(out *Failover) { *out = new(HealthCheck) **out = **in } + if in.Errors != nil { + in, out := &in.Errors, &out.Errors + *out = new(FailoverError) + (*in).DeepCopyInto(*out) + } return } @@ -371,6 +376,32 @@ func (in *Failover) DeepCopy() *Failover { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *FailoverError) DeepCopyInto(out *FailoverError) { + *out = *in + if in.MaxRequestBodyBytes != nil { + in, out := &in.MaxRequestBodyBytes, &out.MaxRequestBodyBytes + *out = new(int64) + **out = **in + } + if in.Status != nil { + in, out := &in.Status, &out.Status + *out = make([]string, len(*in)) + copy(*out, *in) + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FailoverError. +func (in *FailoverError) DeepCopy() *FailoverError { + if in == nil { + return nil + } + out := new(FailoverError) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) { *out = *in diff --git a/pkg/server/service/loadbalancer/failover/failover.go b/pkg/server/service/loadbalancer/failover/failover.go index 07cf519684..4237047680 100644 --- a/pkg/server/service/loadbalancer/failover/failover.go +++ b/pkg/server/service/loadbalancer/failover/failover.go @@ -3,11 +3,14 @@ package failover import ( "context" "errors" + "fmt" "net/http" "sync" "github.com/rs/zerolog/log" "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/server/service/loadbalancer/mirror" + "github.com/traefik/traefik/v3/pkg/types" ) // Failover is an http.Handler that can forward requests to the fallback handler @@ -25,13 +28,33 @@ type Failover struct { fallbackStatusMu sync.RWMutex fallbackStatus bool + + statusCode types.HTTPCodeRanges + maxRequestBodyBytes int64 } // New creates a new Failover handler. -func New(hc *dynamic.HealthCheck) *Failover { - return &Failover{ - wantsHealthCheck: hc != nil, +func New(config *dynamic.Failover) (*Failover, error) { + f := &Failover{wantsHealthCheck: config.HealthCheck != nil} + + if config.Errors != nil { + if len(config.Errors.Status) > 0 { + httpCodeRanges, err := types.NewHTTPCodeRanges(config.Errors.Status) + if err != nil { + return nil, fmt.Errorf("creating HTTP code ranges: %w", err) + } + f.statusCode = httpCodeRanges + } + + maxRequestBodyBytes := dynamic.FailoverErrorsDefaultMaxRequestBodyBytes + if config.Errors.MaxRequestBodyBytes != nil { + maxRequestBodyBytes = *config.Errors.MaxRequestBodyBytes + } + + f.maxRequestBodyBytes = maxRequestBodyBytes } + + return f, nil } // RegisterStatusUpdater adds fn to the list of hooks that are run when the @@ -53,8 +76,33 @@ func (f *Failover) ServeHTTP(w http.ResponseWriter, req *http.Request) { f.handlerStatusMu.RUnlock() if handlerStatus { - f.handler.ServeHTTP(w, req) - return + if len(f.statusCode) == 0 { + f.handler.ServeHTTP(w, req) + + return + } + + // TODO: move reusable request to a common package at some point. + rr, _, err := mirror.NewReusableRequest(req, f.maxRequestBodyBytes) + if err != nil && !errors.Is(err, mirror.ErrBodyTooLarge) { + log.Ctx(req.Context()).Debug().Err(err).Msg("Error while creating reusable request for failover") + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + if errors.Is(err, mirror.ErrBodyTooLarge) { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + return + } + + rw := &responseWriter{ResponseWriter: w, statusCodeRange: f.statusCode} + f.handler.ServeHTTP(rw, rr.Clone(req.Context())) + + if !rw.needFallback { + return + } + + req = rr.Clone(req.Context()) } f.fallbackStatusMu.RLock() diff --git a/pkg/server/service/loadbalancer/failover/failover_test.go b/pkg/server/service/loadbalancer/failover/failover_test.go index 8ca46e2515..c88733a17b 100644 --- a/pkg/server/service/loadbalancer/failover/failover_test.go +++ b/pkg/server/service/loadbalancer/failover/failover_test.go @@ -1,6 +1,8 @@ package failover import ( + "bytes" + "io" "net/http" "net/http/httptest" "testing" @@ -26,7 +28,10 @@ func (r *responseRecorder) WriteHeader(statusCode int) { } func TestFailover(t *testing.T) { - failover := New(&dynamic.HealthCheck{}) + failover, err := New(&dynamic.Failover{ + HealthCheck: &dynamic.HealthCheck{}, + }) + require.NoError(t, err) status := true require.NoError(t, failover.RegisterStatusUpdater(func(up bool) { @@ -73,7 +78,8 @@ func TestFailover(t *testing.T) { } func TestFailoverDownThenUp(t *testing.T) { - failover := New(nil) + failover, err := New(&dynamic.Failover{}) + require.NoError(t, err) failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "handler") @@ -112,7 +118,10 @@ func TestFailoverDownThenUp(t *testing.T) { } func TestFailoverPropagate(t *testing.T) { - failover := New(&dynamic.HealthCheck{}) + failover, err := New(&dynamic.Failover{ + HealthCheck: &dynamic.HealthCheck{}, + }) + require.NoError(t, err) failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "handler") rw.WriteHeader(http.StatusOK) @@ -122,13 +131,14 @@ func TestFailoverPropagate(t *testing.T) { rw.WriteHeader(http.StatusOK) })) - topFailover := New(nil) + topFailover, err := New(&dynamic.Failover{}) + require.NoError(t, err) topFailover.SetHandler(failover) topFailover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("server", "topFailover") rw.WriteHeader(http.StatusOK) })) - err := failover.RegisterStatusUpdater(func(up bool) { + err = failover.RegisterStatusUpdater(func(up bool) { topFailover.SetHandlerStatus(t.Context(), up) }) require.NoError(t, err) @@ -161,3 +171,301 @@ func TestFailoverPropagate(t *testing.T) { assert.Equal(t, 1, recorder.save["topFailover"]) assert.Equal(t, []int{200}, recorder.status) } + +func TestFailoverStatusCode(t *testing.T) { + testCases := []struct { + desc string + statusCode []string + handlerStatusCode int + expectedHandler string + expectedStatusCode int + expectedResponseBody string + }{ + { + desc: "main handler returns 503, failover triggered", + statusCode: []string{"503"}, + handlerStatusCode: 503, + expectedHandler: "fallback", + expectedStatusCode: 200, + expectedResponseBody: "fallback response", + }, + { + desc: "main handler returns 200, failover not triggered", + statusCode: []string{"503"}, + handlerStatusCode: 200, + expectedHandler: "handler", + expectedStatusCode: 200, + expectedResponseBody: "main response", + }, + { + desc: "main handler returns 500, failover not triggered", + statusCode: []string{"503"}, + handlerStatusCode: 500, + expectedHandler: "handler", + expectedStatusCode: 500, + expectedResponseBody: "main response", + }, + { + desc: "multiple status codes, 503 triggers failover", + statusCode: []string{"500", "502", "503", "504"}, + handlerStatusCode: 503, + expectedHandler: "fallback", + expectedStatusCode: 200, + expectedResponseBody: "fallback response", + }, + { + desc: "multiple status codes, 502 triggers failover", + statusCode: []string{"500", "502", "503", "504"}, + handlerStatusCode: 502, + expectedHandler: "fallback", + expectedStatusCode: 200, + expectedResponseBody: "fallback response", + }, + { + desc: "multiple status codes, 404 does not trigger failover", + statusCode: []string{"500", "502", "503", "504"}, + handlerStatusCode: 404, + expectedHandler: "handler", + expectedStatusCode: 404, + expectedResponseBody: "main response", + }, + { + desc: "status code range 500-504, 503 triggers failover", + statusCode: []string{"500-504"}, + handlerStatusCode: 503, + expectedHandler: "fallback", + expectedStatusCode: 200, + expectedResponseBody: "fallback response", + }, + { + desc: "status code range 500-504, 404 does not trigger failover", + statusCode: []string{"500-504"}, + handlerStatusCode: 404, + expectedHandler: "handler", + expectedStatusCode: 404, + expectedResponseBody: "main response", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + failover, err := New(&dynamic.Failover{ + Errors: &dynamic.FailoverError{ + Status: test.statusCode, + }, + }) + require.NoError(t, err) + + failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "handler") + rw.WriteHeader(test.handlerStatusCode) + _, err := rw.Write([]byte("main response")) + require.NoError(t, err) + })) + + failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "fallback") + rw.WriteHeader(http.StatusOK) + _, err := rw.Write([]byte("fallback response")) + require.NoError(t, err) + })) + + recorder := httptest.NewRecorder() + failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.Equal(t, test.expectedHandler, recorder.Header().Get("server")) + assert.Equal(t, test.expectedStatusCode, recorder.Code) + assert.Equal(t, test.expectedResponseBody, recorder.Body.String()) + }) + } +} + +func TestFailoverStatusCodeWithRequestBody(t *testing.T) { + testCases := []struct { + desc string + statusCode []string + handlerStatusCode int + requestBody string + expectedHandler string + }{ + { + desc: "request body replayed to fallback handler", + statusCode: []string{"503"}, + handlerStatusCode: 503, + requestBody: "test request body", + expectedHandler: "fallback", + }, + { + desc: "request body used by main handler", + statusCode: []string{"503"}, + handlerStatusCode: 200, + requestBody: "test request body", + expectedHandler: "handler", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + var receivedBody string + + maxBody := int64(-1) + failover, err := New(&dynamic.Failover{ + Errors: &dynamic.FailoverError{ + Status: test.statusCode, + MaxRequestBodyBytes: &maxBody, + }, + }) + require.NoError(t, err) + + failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + body, _ := io.ReadAll(req.Body) + receivedBody = string(body) + rw.Header().Set("server", "handler") + rw.WriteHeader(test.handlerStatusCode) + })) + + failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + body, _ := io.ReadAll(req.Body) + receivedBody = string(body) + rw.Header().Set("server", "fallback") + rw.WriteHeader(http.StatusOK) + })) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody)) + failover.ServeHTTP(recorder, req) + + assert.Equal(t, test.expectedHandler, recorder.Header().Get("server")) + assert.Equal(t, test.requestBody, receivedBody) + }) + } +} + +func TestFailoverStatusCodeMaxBodySize(t *testing.T) { + testCases := []struct { + desc string + maxBodySize int64 + requestBody string + expectedStatusCode int + expectedMessage string + }{ + { + desc: "request body within limit", + maxBodySize: 100, + requestBody: "small body", + expectedStatusCode: 200, + expectedMessage: "", + }, + { + desc: "request body exceeds limit", + maxBodySize: 5, + requestBody: "this body is too large", + expectedStatusCode: http.StatusRequestEntityTooLarge, + expectedMessage: "Request body too large\n", + }, + { + desc: "zero body size limit", + maxBodySize: 0, + requestBody: "any size body", + expectedStatusCode: http.StatusRequestEntityTooLarge, + expectedMessage: "Request body too large\n", + }, + { + desc: "no body size limit", + maxBodySize: -1, + requestBody: "any size body should work", + expectedStatusCode: 200, + expectedMessage: "", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + maxBody := test.maxBodySize + failover, err := New(&dynamic.Failover{ + Errors: &dynamic.FailoverError{ + Status: []string{"503"}, + MaxRequestBodyBytes: &maxBody, + }, + }) + require.NoError(t, err) + + failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusServiceUnavailable) + })) + + failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody)) + req.ContentLength = int64(len(test.requestBody)) + failover.ServeHTTP(recorder, req) + + assert.Equal(t, test.expectedStatusCode, recorder.Code) + if test.expectedMessage != "" { + assert.Equal(t, test.expectedMessage, recorder.Body.String()) + } + }) + } +} + +func TestFailoverStatusCodeWithHealthCheck(t *testing.T) { + failover, err := New(&dynamic.Failover{ + HealthCheck: &dynamic.HealthCheck{}, + Errors: &dynamic.FailoverError{ + Status: []string{"503"}, + }, + }) + require.NoError(t, err) + + failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "handler") + rw.WriteHeader(http.StatusServiceUnavailable) + })) + + failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "fallback") + rw.WriteHeader(http.StatusOK) + })) + + // Test 1: Handler is up, returns 503, should failover based on status code + recorder := httptest.NewRecorder() + failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.Equal(t, "fallback", recorder.Header().Get("server")) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Test 2: Handler is marked down via health check, should failover + failover.SetHandlerStatus(t.Context(), false) + + recorder = httptest.NewRecorder() + failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.Equal(t, "fallback", recorder.Header().Get("server")) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Test 3: Handler is back up but returns non-503 status, should not failover + failover.SetHandlerStatus(t.Context(), true) + failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "handler") + rw.WriteHeader(http.StatusOK) + })) + + recorder = httptest.NewRecorder() + failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.Equal(t, "handler", recorder.Header().Get("server")) + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestFailoverInvalidStatusCodeRange(t *testing.T) { + _, err := New(&dynamic.Failover{ + Errors: &dynamic.FailoverError{ + Status: []string{"invalid"}, + }, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "strconv") +} diff --git a/pkg/server/service/loadbalancer/failover/response_writer.go b/pkg/server/service/loadbalancer/failover/response_writer.go new file mode 100644 index 0000000000..b3671b79ac --- /dev/null +++ b/pkg/server/service/loadbalancer/failover/response_writer.go @@ -0,0 +1,77 @@ +package failover + +import ( + "bufio" + "fmt" + "net" + "net/http" + + "github.com/traefik/traefik/v3/pkg/types" +) + +type responseWriter struct { + http.ResponseWriter + + needFallback bool + written bool + header http.Header + statusCodeRange types.HTTPCodeRanges +} + +func (r *responseWriter) Write(b []byte) (int, error) { + if !r.written { + r.WriteHeader(http.StatusOK) + } + + if r.needFallback { + // As we will fallback, we can discard the response body. + return len(b), nil + } + + return r.ResponseWriter.Write(b) +} + +func (r *responseWriter) Header() http.Header { + if r.header == nil { + r.header = make(http.Header) + } + + return r.header +} + +func (r *responseWriter) WriteHeader(statusCode int) { + if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols { + clear(r.header) + + return + } + + if !r.written { + r.written = true + r.needFallback = r.statusCodeRange.Contains(statusCode) + + if !r.needFallback { + for k, v := range r.header { + for _, vv := range v { + r.ResponseWriter.Header().Add(k, vv) + } + } + + r.ResponseWriter.WriteHeader(statusCode) + } + } +} + +func (r *responseWriter) Flush() { + if flusher, ok := r.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + + return nil, nil, fmt.Errorf("not a hijacker: %T", r.ResponseWriter) +} diff --git a/pkg/server/service/loadbalancer/mirror/mirror.go b/pkg/server/service/loadbalancer/mirror/mirror.go index 9b4dbb804b..bd5f4d70bc 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror.go +++ b/pkg/server/service/loadbalancer/mirror/mirror.go @@ -54,6 +54,10 @@ type mirrorHandler struct { count uint64 } +type clonableRequest interface { + Clone(ctx context.Context) *http.Request +} + func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { mirrors := m.getActiveMirrors() if len(mirrors) == 0 { @@ -62,21 +66,28 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } logger := log.Ctx(req.Context()) - rr, bytesRead, err := newReusableRequest(req, m.mirrorBody, m.maxBodySize) - if err != nil && !errors.Is(err, errBodyTooLarge) { - http.Error(rw, fmt.Sprintf("%s: creating reusable request: %v", - http.StatusText(http.StatusInternalServerError), err), http.StatusInternalServerError) - return + + var rr clonableRequest = req + + if m.mirrorBody { + var err error + var bytesRead []byte + rr, bytesRead, err = NewReusableRequest(req, m.maxBodySize) + if err != nil && !errors.Is(err, ErrBodyTooLarge) { + logger.Debug().Err(err).Msg("Error while creating reusable request for mirroring") + http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + if errors.Is(err, ErrBodyTooLarge) { + req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body)) + m.handler.ServeHTTP(rw, req) + logger.Debug().Msg("No mirroring, request body larger than allowed size") + return + } } - if errors.Is(err, errBodyTooLarge) { - req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body)) - m.handler.ServeHTTP(rw, req) - logger.Debug().Msg("No mirroring, request body larger than allowed size") - return - } - - m.handler.ServeHTTP(rw, rr.clone(req.Context())) + m.handler.ServeHTTP(rw, rr.Clone(req.Context())) select { case <-req.Context().Done(): @@ -89,7 +100,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { m.routinePool.GoCtx(func(_ context.Context) { for _, handler := range mirrors { // prepare request, update body from buffer - r := rr.clone(req.Context()) + r := rr.Clone(req.Context()) // In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context // to know whether we should mutate said datatable (and contribute some fields to the log). @@ -192,23 +203,23 @@ func (c contextStopPropagation) Done() <-chan struct{} { return make(chan struct{}) } -// reusableRequest keeps in memory the body of the given request, +// ReusableRequest keeps in memory the body of the given request, // so that the request can be fully cloned by each mirror. -type reusableRequest struct { +type ReusableRequest struct { req *http.Request body []byte } -var errBodyTooLarge = errors.New("request body too large") +var ErrBodyTooLarge = errors.New("request body too large") -// if the returned error is errBodyTooLarge, newReusableRequest also returns the -// bytes that were already consumed from the request's body. -func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (*reusableRequest, []byte, error) { +// NewReusableRequest returns a new reusable request. If the returned error is ErrBodyTooLarge, NewReusableRequest also returns the +// bytes already consumed from the request's body. +func NewReusableRequest(req *http.Request, maxBodySize int64) (*ReusableRequest, []byte, error) { if req == nil { return nil, nil, errors.New("nil input request") } - if req.Body == nil || req.ContentLength == 0 || !mirrorBody { - return &reusableRequest{req: req}, nil, nil + if req.Body == nil || req.ContentLength == 0 { + return &ReusableRequest{req: req}, nil, nil } // unbounded body size @@ -217,7 +228,7 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) ( if err != nil { return nil, nil, err } - return &reusableRequest{ + return &ReusableRequest{ req: req, body: body, }, nil, nil @@ -234,17 +245,18 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) ( // we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read, // which permits us sending also to all the mirrors later. if errors.Is(err, io.ErrUnexpectedEOF) { - return &reusableRequest{ + return &ReusableRequest{ req: req, body: body[:n], }, nil, nil } // err == nil , which means data size > maxBodySize - return nil, body[:n], errBodyTooLarge + return nil, body[:n], ErrBodyTooLarge } -func (rr reusableRequest) clone(ctx context.Context) *http.Request { +// Clone clones the request. +func (rr ReusableRequest) Clone(ctx context.Context) *http.Request { req := rr.req.Clone(ctx) if rr.body != nil { diff --git a/pkg/server/service/loadbalancer/mirror/mirror_test.go b/pkg/server/service/loadbalancer/mirror/mirror_test.go index 515bafdb77..0f0f934928 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror_test.go +++ b/pkg/server/service/loadbalancer/mirror/mirror_test.go @@ -224,16 +224,16 @@ func TestCloneRequest(t *testing.T) { assert.NoError(t, err) ctx := req.Context() - rr, _, err := newReusableRequest(req, true, defaultMaxBodySize) + rr, _, err := NewReusableRequest(req, defaultMaxBodySize) assert.NoError(t, err) // first call - cloned := rr.clone(ctx) + cloned := rr.Clone(ctx) assert.Equal(t, cloned, req) assert.Nil(t, cloned.Body) // second call - cloned = rr.clone(ctx) + cloned = rr.Clone(ctx) assert.Equal(t, cloned, req) assert.Nil(t, cloned.Body) }) @@ -249,17 +249,17 @@ func TestCloneRequest(t *testing.T) { ctx := req.Context() req.ContentLength = int64(contentLength) - rr, _, err := newReusableRequest(req, true, defaultMaxBodySize) + rr, _, err := NewReusableRequest(req, defaultMaxBodySize) assert.NoError(t, err) // first call - cloned := rr.clone(ctx) + cloned := rr.Clone(ctx) body, err := io.ReadAll(cloned.Body) assert.NoError(t, err) assert.Equal(t, bb, body) // second call - cloned = rr.clone(ctx) + cloned = rr.Clone(ctx) body, err = io.ReadAll(cloned.Body) assert.NoError(t, err) assert.Equal(t, bb, body) @@ -272,7 +272,7 @@ func TestCloneRequest(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "/", buf) assert.NoError(t, err) - _, expectedBytes, err := newReusableRequest(req, true, 2) + _, expectedBytes, err := NewReusableRequest(req, 2) assert.Error(t, err) assert.Equal(t, expectedBytes, bb[:3]) }) @@ -284,7 +284,7 @@ func TestCloneRequest(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "/", buf) assert.NoError(t, err) - rr, expectedBytes, err := newReusableRequest(req, true, 20) + rr, expectedBytes, err := NewReusableRequest(req, 20) assert.NoError(t, err) assert.Nil(t, expectedBytes) assert.Len(t, rr.body, 10) @@ -296,14 +296,14 @@ func TestCloneRequest(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "/", buf) assert.NoError(t, err) - rr, expectedBytes, err := newReusableRequest(req, true, 20) + rr, expectedBytes, err := NewReusableRequest(req, 20) assert.NoError(t, err) assert.Nil(t, expectedBytes) assert.Empty(t, rr.body) }) t.Run("no request given", func(t *testing.T) { - _, _, err := newReusableRequest(nil, true, defaultMaxBodySize) + _, _, err := NewReusableRequest(nil, defaultMaxBodySize) assert.Error(t, err) }) } diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index 599c943e03..ff1141882d 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -210,24 +210,29 @@ func (m *Manager) LaunchHealthCheck(ctx context.Context) { } func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName string, config *dynamic.Failover) (http.Handler, error) { - f := failover.New(config.HealthCheck) - serviceHandler, err := m.BuildHTTP(ctx, config.Service) if err != nil { return nil, err } - f.SetHandler(serviceHandler) - - updater, ok := serviceHandler.(healthcheck.StatusUpdater) - if !ok { + updater, implementUpdater := serviceHandler.(healthcheck.StatusUpdater) + isErrorDefined := config.Errors != nil && len(config.Errors.Status) > 0 + if !implementUpdater && !isErrorDefined { return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Service, serviceName, serviceHandler) } - if err := updater.RegisterStatusUpdater(func(up bool) { - f.SetHandlerStatus(ctx, up) - }); err != nil { - return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err) + f, err := failover.New(config) + if err != nil { + return nil, fmt.Errorf("error creating failover service %v: %w", serviceName, err) + } + f.SetHandler(serviceHandler) + + if implementUpdater { + if err := updater.RegisterStatusUpdater(func(up bool) { + f.SetHandlerStatus(ctx, up) + }); err != nil && !isErrorDefined { + return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err) + } } fallbackHandler, err := m.BuildHTTP(ctx, config.Fallback) @@ -242,8 +247,8 @@ func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName str return f, nil } - fallbackUpdater, ok := fallbackHandler.(healthcheck.StatusUpdater) - if !ok { + fallbackUpdater, implementUpdater := fallbackHandler.(healthcheck.StatusUpdater) + if !implementUpdater { return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Fallback, serviceName, fallbackHandler) }