Failover according to response status code
Some checks are pending
CodeQL / Analyze (push) Waiting to run
Build and Publish Documentation / Doc Process (push) Waiting to run
Build experimental image on branch / build-webui (push) Waiting to run
Build experimental image on branch / Build experimental image on branch (push) Waiting to run

Co-authored-by: juliens <julien.salleyron@gmail.com>
This commit is contained in:
Landry Benguigui 2026-02-09 14:10:06 +01:00 committed by GitHub
parent a4a91344ed
commit 34ae66b9ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 998 additions and 67 deletions

View file

@ -760,7 +760,7 @@ The `mirroring` service type mirrors requests sent to a service to other service
!!! info "Supported Providers"
This service type can be defined currently with the [File](../../../install-configuration/providers/others/file.md) provider or [IngressRoute](../../../routing-configuration/kubernetes/crd/http/ingressroute.md).
```yaml tab="Structured (YAML)"
## Routing configuration
http:
@ -887,9 +887,13 @@ http:
url = "http://private-ip-server-2/"
```
### Failover
### Failover
The `failover` service type forwards all requests to a fallback service when the main service becomes unreachable.
The `failover` service type forwards requests to a fallback service when the main service is unavailable.
Failover can be triggered in two ways:
- **Health check-based**: When the main service becomes unreachable based on [health checks](#health-check).
- **Status code-based**: When the main service responds with specific HTTP status codes defined in the [errors](#errors) configuration.
!!! info "Relation to HealthCheck"
The failover service relies on the HealthCheck system to get notified when its main service becomes unreachable, which means HealthCheck needs to be enabled and functional on the main service. However, HealthCheck does not need to be enabled on the failover service itself for it to be functional. It is only required in order to propagate upwards the information when the failover itself becomes down (i.e. both its main and its fallback are down too).
@ -940,15 +944,15 @@ http:
## Routing configuration
[http.services]
[http.services.app]
[http.services.app.failover.healthCheck]
[http.services.app.failover]
service = "main"
fallback = "backup"
[http.services.app.failover.healthCheck]
[http.services.main]
[http.services.main.loadBalancer]
[http.services.main.loadBalancer.healthCheck]
path = "/health"
path = "/status"
interval = "10s"
timeout = "3s"
[[http.services.main.loadBalancer.servers]]
@ -957,9 +961,163 @@ http:
[http.services.backup]
[http.services.backup.loadBalancer]
[http.services.backup.loadBalancer.healthCheck]
path = "/health"
path = "/status"
interval = "10s"
timeout = "3s"
[[http.services.backup.loadBalancer.servers]]
url = "http://private-ip-server-2/"
```
#### Errors
The `errors` option enables status code-based failover.
When the main service responds with an HTTP status code matching one of the configured ranges, Traefik automatically retries the request on the fallback service.
To support request replay, the request body is buffered up to `maxRequestBodyBytes`.
Requests with bodies larger than this limit receive a `413 Request Entity Too Large` response.
Below is a list of options available for the `errors` option and an example of how to configure it for a failover service:
| Field | Description | Default |
|-----------------------|-------------------------------------------------------------------------------------------------------------------|---------|
| <a id="opt-status-2" href="#opt-status-2" title="#opt-status-2">`status`</a> | List of HTTP status code ranges that trigger failover. Supports single codes (`"500"`) and ranges (`"500-504"`). | None |
| <a id="opt-maxRequestBodyBytes" href="#opt-maxRequestBodyBytes" title="#opt-maxRequestBodyBytes">`maxRequestBodyBytes`</a> | Maximum request body size (in bytes) to buffer for replay to the fallback service. Set to `-1` for no limit. | `-1` |
```yaml tab="Structured (YAML)"
## Routing configuration
http:
services:
app:
failover:
service: main
fallback: backup
errors:
status:
- "500-504"
maxRequestBodyBytes: 1048576
main:
loadBalancer:
servers:
- url: "http://private-ip-server-1/"
backup:
loadBalancer:
servers:
- url: "http://private-ip-server-2/"
```
```toml tab="Structured (TOML)"
## Routing configuration
[http.services]
[http.services.app]
[http.services.app.failover]
service = "main"
fallback = "backup"
[http.services.app.failover.errors]
status = ["500-504"]
maxRequestBodyBytes = 1048576
[http.services.main]
[http.services.main.loadBalancer]
[[http.services.main.loadBalancer.servers]]
url = "http://private-ip-server-1/"
[http.services.backup]
[http.services.backup.loadBalancer]
[[http.services.backup.loadBalancer.servers]]
url = "http://private-ip-server-2/"
```
#### Chaining Failover Services
Failover services can be chained together for multi-level redundancy.
In the following example, if the primary service fails, traffic goes to the secondary service.
If both primary and secondary fail, traffic goes to the tertiary service.
```yaml tab="Structured (YAML)"
## Routing configuration
http:
services:
app:
failover:
healthCheck: {}
service: primary-failover
fallback: tertiary
primary-failover:
failover:
healthCheck: {}
service: primary
fallback: secondary
primary:
loadBalancer:
healthCheck:
path: /health
interval: 10s
timeout: 3s
servers:
- url: "http://primary-server/"
secondary:
loadBalancer:
healthCheck:
path: /health
interval: 10s
timeout: 3s
servers:
- url: "http://secondary-server/"
tertiary:
loadBalancer:
healthCheck:
path: /health
interval: 10s
timeout: 3s
servers:
- url: "http://tertiary-server/"
```
```toml tab="Structured (TOML)"
## Routing configuration
[http.services]
[http.services.app]
[http.services.app.failover]
service = "primary-failover"
fallback = "tertiary"
[http.services.app.failover.healthCheck]
[http.services.primary-failover]
[http.services.primary-failover.failover]
service = "primary"
fallback = "secondary"
[http.services.primary-failover.failover.healthCheck]
[http.services.primary]
[http.services.primary.loadBalancer]
[http.services.primary.loadBalancer.healthCheck]
path = "/health"
interval = "10s"
timeout = "3s"
[[http.services.primary.loadBalancer.servers]]
url = "http://primary-server/"
[http.services.secondary]
[http.services.secondary.loadBalancer]
[http.services.secondary.loadBalancer.healthCheck]
path = "/health"
interval = "10s"
timeout = "3s"
[[http.services.secondary.loadBalancer.servers]]
url = "http://secondary-server/"
[http.services.tertiary]
[http.services.tertiary.loadBalancer]
[http.services.tertiary.loadBalancer.healthCheck]
path = "/health"
interval = "10s"
timeout = "3s"
[[http.services.tertiary.loadBalancer.servers]]
url = "http://tertiary-server/"
```

View file

@ -56,6 +56,9 @@
service = "foobar"
fallback = "foobar"
[http.services.Service01.failover.healthCheck]
[http.services.Service01.failover.errors]
maxRequestBodyBytes = 42
status = ["foobar", "foobar"]
[http.services.Service02]
[http.services.Service02.highestRandomWeight]

View file

@ -70,6 +70,11 @@ http:
service: foobar
fallback: foobar
healthCheck: {}
errors:
maxRequestBodyBytes: 42
status:
- foobar
- foobar
Service02:
highestRandomWeight:
services:

View file

@ -0,0 +1,53 @@
[global]
checkNewVersion = false
sendAnonymousUsage = false
[api]
insecure = true
[log]
level = "DEBUG"
noColor = true
[entryPoints]
[entryPoints.web]
address = ":8000"
[providers.file]
filename = "{{ .SelfFilename }}"
## dynamic configuration ##
[http.routers]
[http.routers.router1]
entrypoints = ["web"]
service = "failover-service"
rule = "Path(`/whoami`)"
[http.services]
# Failover service with health check
[http.services.failover-service]
[http.services.failover-service.failover]
service = "main-service"
fallback = "fallback-service"
[http.services.failover-service.failover.healthCheck]
# Main service with health check enabled
[http.services.main-service]
[http.services.main-service.loadBalancer]
[http.services.main-service.loadBalancer.healthCheck]
path = "/health"
interval = "1s"
timeout = "0.9s"
[[http.services.main-service.loadBalancer.servers]]
url = "http://{{ .MainServer }}:80"
# Fallback service with health check enabled
[http.services.fallback-service]
[http.services.fallback-service.loadBalancer]
[http.services.fallback-service.loadBalancer.healthCheck]
path = "/health"
interval = "1s"
timeout = "0.9s"
[[http.services.fallback-service.loadBalancer.servers]]
url = "http://{{ .FallbackServer }}:80"

View file

@ -0,0 +1,47 @@
[global]
checkNewVersion = false
sendAnonymousUsage = false
[api]
insecure = true
[log]
level = "DEBUG"
noColor = true
[entryPoints]
[entryPoints.web]
address = ":8000"
[providers.file]
filename = "{{ .SelfFilename }}"
## dynamic configuration ##
[http.routers]
[http.routers.router1]
entrypoints = ["web"]
service = "failover-service"
rule = "PathPrefix(`/`)"
[http.services]
# Failover service with error-based failover (status codes)
[http.services.failover-service]
[http.services.failover-service.failover]
service = "main-service"
fallback = "fallback-service"
[http.services.failover-service.failover.errors]
status = ["500-504"]
maxRequestBodyBytes = 1048576 # 1MB
# Main service (no health check - failover based on status codes only)
[http.services.main-service]
[http.services.main-service.loadBalancer]
[[http.services.main-service.loadBalancer.servers]]
url = "{{ .MainServer }}"
# Fallback service
[http.services.fallback-service]
[http.services.fallback-service.loadBalancer]
[[http.services.fallback-service.loadBalancer.servers]]
url = "{{ .FallbackServer }}"

View file

@ -2365,6 +2365,173 @@ func (s *SimpleSuite) TestEncodedCharactersDifferentEntryPoints() {
}
}
func (s *SimpleSuite) TestFailoverService() {
s.createComposeProject("base")
s.composeUp()
defer s.composeDown()
whoami1IP := s.getComposeServiceIP("whoami1")
whoami2IP := s.getComposeServiceIP("whoami2")
file := s.adaptFile("fixtures/failover.toml", struct {
MainServer string
FallbackServer string
}{
MainServer: whoami1IP,
FallbackServer: whoami2IP,
})
s.traefikCmd(withConfigFile(file))
// Wait for Traefik to be ready
err := try.GetRequest("http://127.0.0.1:8080/api/http/services", 2*time.Second, try.BodyContains("failover-service"))
require.NoError(s.T(), err)
// Test 1: When main service is healthy, traffic should go to main
var primaryCount, fallbackCount int
for range 5 {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
require.NoError(s.T(), err)
response, err := http.DefaultClient.Do(req)
require.NoError(s.T(), err)
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(s.T(), err)
if strings.Contains(string(body), whoami1IP) {
primaryCount++
}
if strings.Contains(string(body), whoami2IP) {
fallbackCount++
}
}
// All requests should go to the main service (whoami1)
assert.Equal(s.T(), 5, primaryCount, "Expected all requests to go to main service")
assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service")
// Test 2: Stop the main service to trigger failover via health check
s.composeStop("whoami1")
// Wait for health check to detect the main service is down
time.Sleep(3 * time.Second)
// Now all traffic should go to the fallback service (whoami2)
primaryCount, fallbackCount = 0, 0
for range 5 {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
require.NoError(s.T(), err)
response, err := http.DefaultClient.Do(req)
require.NoError(s.T(), err)
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(s.T(), err)
if strings.Contains(string(body), whoami1IP) {
primaryCount++
}
if strings.Contains(string(body), whoami2IP) {
fallbackCount++
}
}
assert.Equal(s.T(), 0, primaryCount, "Expected no requests to go to main service when down")
assert.Equal(s.T(), 5, fallbackCount, "Expected all requests to go to fallback service")
// Test 3: Restart main service and verify traffic returns to main
s.composeUp("whoami1")
// Wait for health check to detect the main service is back up
time.Sleep(3 * time.Second)
primaryCount, fallbackCount = 0, 0
for range 5 {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
require.NoError(s.T(), err)
response, err := http.DefaultClient.Do(req)
require.NoError(s.T(), err)
assert.Equal(s.T(), http.StatusOK, response.StatusCode)
body, err := io.ReadAll(response.Body)
require.NoError(s.T(), err)
if strings.Contains(string(body), whoami1IP) {
primaryCount++
}
if strings.Contains(string(body), whoami2IP) {
fallbackCount++
}
}
// Traffic should return to the main service
assert.Equal(s.T(), 5, primaryCount, "Expected all requests to return to main service when back up")
assert.Equal(s.T(), 0, fallbackCount, "Expected no requests to go to fallback service")
// Test 4: Stop both services and verify we get 503
s.composeStop("whoami1")
s.composeStop("whoami2")
// Wait for health checks to detect both services are down
time.Sleep(3 * time.Second)
// Request should return 503 Service Unavailable when both services are down
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", nil)
require.NoError(s.T(), err)
response, err := http.DefaultClient.Do(req)
require.NoError(s.T(), err)
assert.Equal(s.T(), http.StatusServiceUnavailable, response.StatusCode)
}
func (s *SimpleSuite) TestFailoverServiceWithStatusCode() {
var mainCallCount, fallbackCallCount atomic.Int32
// Create a test server that returns 503 to trigger error-based failover
mainServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
mainCallCount.Add(1)
rw.WriteHeader(http.StatusServiceUnavailable)
_, _ = rw.Write([]byte("main service unavailable"))
}))
defer mainServer.Close()
// Create a fallback server that returns 200
fallbackServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fallbackCallCount.Add(1)
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("fallback service"))
}))
defer fallbackServer.Close()
file := s.adaptFile("fixtures/failover_statuscode.toml", struct {
MainServer string
FallbackServer string
}{
MainServer: mainServer.URL,
FallbackServer: fallbackServer.URL,
})
s.traefikCmd(withConfigFile(file))
// Wait for Traefik to be ready and verify the configuration is loaded
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("PathPrefix"))
require.NoError(s.T(), err)
// Make a request - should failover to fallback because main returns 503
err = try.GetRequest("http://127.0.0.1:8000/", 5*time.Second, try.BodyContains("fallback service"))
require.NoError(s.T(), err)
// Main was called but returned 503, triggering failover to fallback
assert.GreaterOrEqual(s.T(), mainCallCount.Load(), int32(1), "Main service should have been called at least once")
assert.GreaterOrEqual(s.T(), fallbackCallCount.Load(), int32(1), "Fallback service should have been called at least once")
}
func (s *SimpleSuite) TestServiceMiddleware() {
s.createComposeProject("base")

View file

@ -10,6 +10,7 @@ import (
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
"github.com/traefik/traefik/v3/pkg/types"
"google.golang.org/grpc/codes"
"k8s.io/utils/ptr"
)
const (
@ -28,6 +29,8 @@ const (
MirroringDefaultMirrorBody = true
// MirroringDefaultMaxBodySize is the Mirroring.MaxBodySize option default value.
MirroringDefaultMaxBodySize int64 = -1
// FailoverErrorsDefaultMaxRequestBodyBytes is the Failover.Errors.MaxBodySize option default value.
FailoverErrorsDefaultMaxRequestBodyBytes int64 = -1
)
// +k8s:deepcopy-gen=true
@ -195,9 +198,23 @@ func (m *Mirroring) SetDefaults() {
// Failover holds the Failover configuration.
type Failover struct {
Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"`
Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"`
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"`
Fallback string `json:"fallback,omitempty" toml:"fallback,omitempty" yaml:"fallback,omitempty" export:"true"`
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
Errors *FailoverError `json:"errors,omitempty" toml:"errors,omitempty" yaml:"errors,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true
// FailoverError holds errors configuration.
type FailoverError struct {
MaxRequestBodyBytes *int64 `json:"maxRequestBodyBytes,omitempty" toml:"maxRequestBodyBytes,omitempty" yaml:"maxRequestBodyBytes,omitempty" export:"true"`
Status []string `json:"status,omitempty" toml:"status,omitempty" yaml:"status,omitempty" export:"true"`
}
// SetDefaults Default values for a WRRService.
func (m *FailoverError) SetDefaults() {
m.MaxRequestBodyBytes = ptr.To(FailoverErrorsDefaultMaxRequestBodyBytes)
}
// +k8s:deepcopy-gen=true

View file

@ -358,6 +358,11 @@ func (in *Failover) DeepCopyInto(out *Failover) {
*out = new(HealthCheck)
**out = **in
}
if in.Errors != nil {
in, out := &in.Errors, &out.Errors
*out = new(FailoverError)
(*in).DeepCopyInto(*out)
}
return
}
@ -371,6 +376,32 @@ func (in *Failover) DeepCopy() *Failover {
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *FailoverError) DeepCopyInto(out *FailoverError) {
*out = *in
if in.MaxRequestBodyBytes != nil {
in, out := &in.MaxRequestBodyBytes, &out.MaxRequestBodyBytes
*out = new(int64)
**out = **in
}
if in.Status != nil {
in, out := &in.Status, &out.Status
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FailoverError.
func (in *FailoverError) DeepCopy() *FailoverError {
if in == nil {
return nil
}
out := new(FailoverError)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = *in

View file

@ -3,11 +3,14 @@ package failover
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/server/service/loadbalancer/mirror"
"github.com/traefik/traefik/v3/pkg/types"
)
// Failover is an http.Handler that can forward requests to the fallback handler
@ -25,13 +28,33 @@ type Failover struct {
fallbackStatusMu sync.RWMutex
fallbackStatus bool
statusCode types.HTTPCodeRanges
maxRequestBodyBytes int64
}
// New creates a new Failover handler.
func New(hc *dynamic.HealthCheck) *Failover {
return &Failover{
wantsHealthCheck: hc != nil,
func New(config *dynamic.Failover) (*Failover, error) {
f := &Failover{wantsHealthCheck: config.HealthCheck != nil}
if config.Errors != nil {
if len(config.Errors.Status) > 0 {
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Errors.Status)
if err != nil {
return nil, fmt.Errorf("creating HTTP code ranges: %w", err)
}
f.statusCode = httpCodeRanges
}
maxRequestBodyBytes := dynamic.FailoverErrorsDefaultMaxRequestBodyBytes
if config.Errors.MaxRequestBodyBytes != nil {
maxRequestBodyBytes = *config.Errors.MaxRequestBodyBytes
}
f.maxRequestBodyBytes = maxRequestBodyBytes
}
return f, nil
}
// RegisterStatusUpdater adds fn to the list of hooks that are run when the
@ -53,8 +76,33 @@ func (f *Failover) ServeHTTP(w http.ResponseWriter, req *http.Request) {
f.handlerStatusMu.RUnlock()
if handlerStatus {
f.handler.ServeHTTP(w, req)
return
if len(f.statusCode) == 0 {
f.handler.ServeHTTP(w, req)
return
}
// TODO: move reusable request to a common package at some point.
rr, _, err := mirror.NewReusableRequest(req, f.maxRequestBodyBytes)
if err != nil && !errors.Is(err, mirror.ErrBodyTooLarge) {
log.Ctx(req.Context()).Debug().Err(err).Msg("Error while creating reusable request for failover")
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if errors.Is(err, mirror.ErrBodyTooLarge) {
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
return
}
rw := &responseWriter{ResponseWriter: w, statusCodeRange: f.statusCode}
f.handler.ServeHTTP(rw, rr.Clone(req.Context()))
if !rw.needFallback {
return
}
req = rr.Clone(req.Context())
}
f.fallbackStatusMu.RLock()

View file

@ -1,6 +1,8 @@
package failover
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
@ -26,7 +28,10 @@ func (r *responseRecorder) WriteHeader(statusCode int) {
}
func TestFailover(t *testing.T) {
failover := New(&dynamic.HealthCheck{})
failover, err := New(&dynamic.Failover{
HealthCheck: &dynamic.HealthCheck{},
})
require.NoError(t, err)
status := true
require.NoError(t, failover.RegisterStatusUpdater(func(up bool) {
@ -73,7 +78,8 @@ func TestFailover(t *testing.T) {
}
func TestFailoverDownThenUp(t *testing.T) {
failover := New(nil)
failover, err := New(&dynamic.Failover{})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "handler")
@ -112,7 +118,10 @@ func TestFailoverDownThenUp(t *testing.T) {
}
func TestFailoverPropagate(t *testing.T) {
failover := New(&dynamic.HealthCheck{})
failover, err := New(&dynamic.Failover{
HealthCheck: &dynamic.HealthCheck{},
})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "handler")
rw.WriteHeader(http.StatusOK)
@ -122,13 +131,14 @@ func TestFailoverPropagate(t *testing.T) {
rw.WriteHeader(http.StatusOK)
}))
topFailover := New(nil)
topFailover, err := New(&dynamic.Failover{})
require.NoError(t, err)
topFailover.SetHandler(failover)
topFailover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "topFailover")
rw.WriteHeader(http.StatusOK)
}))
err := failover.RegisterStatusUpdater(func(up bool) {
err = failover.RegisterStatusUpdater(func(up bool) {
topFailover.SetHandlerStatus(t.Context(), up)
})
require.NoError(t, err)
@ -161,3 +171,301 @@ func TestFailoverPropagate(t *testing.T) {
assert.Equal(t, 1, recorder.save["topFailover"])
assert.Equal(t, []int{200}, recorder.status)
}
func TestFailoverStatusCode(t *testing.T) {
testCases := []struct {
desc string
statusCode []string
handlerStatusCode int
expectedHandler string
expectedStatusCode int
expectedResponseBody string
}{
{
desc: "main handler returns 503, failover triggered",
statusCode: []string{"503"},
handlerStatusCode: 503,
expectedHandler: "fallback",
expectedStatusCode: 200,
expectedResponseBody: "fallback response",
},
{
desc: "main handler returns 200, failover not triggered",
statusCode: []string{"503"},
handlerStatusCode: 200,
expectedHandler: "handler",
expectedStatusCode: 200,
expectedResponseBody: "main response",
},
{
desc: "main handler returns 500, failover not triggered",
statusCode: []string{"503"},
handlerStatusCode: 500,
expectedHandler: "handler",
expectedStatusCode: 500,
expectedResponseBody: "main response",
},
{
desc: "multiple status codes, 503 triggers failover",
statusCode: []string{"500", "502", "503", "504"},
handlerStatusCode: 503,
expectedHandler: "fallback",
expectedStatusCode: 200,
expectedResponseBody: "fallback response",
},
{
desc: "multiple status codes, 502 triggers failover",
statusCode: []string{"500", "502", "503", "504"},
handlerStatusCode: 502,
expectedHandler: "fallback",
expectedStatusCode: 200,
expectedResponseBody: "fallback response",
},
{
desc: "multiple status codes, 404 does not trigger failover",
statusCode: []string{"500", "502", "503", "504"},
handlerStatusCode: 404,
expectedHandler: "handler",
expectedStatusCode: 404,
expectedResponseBody: "main response",
},
{
desc: "status code range 500-504, 503 triggers failover",
statusCode: []string{"500-504"},
handlerStatusCode: 503,
expectedHandler: "fallback",
expectedStatusCode: 200,
expectedResponseBody: "fallback response",
},
{
desc: "status code range 500-504, 404 does not trigger failover",
statusCode: []string{"500-504"},
handlerStatusCode: 404,
expectedHandler: "handler",
expectedStatusCode: 404,
expectedResponseBody: "main response",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
failover, err := New(&dynamic.Failover{
Errors: &dynamic.FailoverError{
Status: test.statusCode,
},
})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "handler")
rw.WriteHeader(test.handlerStatusCode)
_, err := rw.Write([]byte("main response"))
require.NoError(t, err)
}))
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "fallback")
rw.WriteHeader(http.StatusOK)
_, err := rw.Write([]byte("fallback response"))
require.NoError(t, err)
}))
recorder := httptest.NewRecorder()
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, test.expectedHandler, recorder.Header().Get("server"))
assert.Equal(t, test.expectedStatusCode, recorder.Code)
assert.Equal(t, test.expectedResponseBody, recorder.Body.String())
})
}
}
func TestFailoverStatusCodeWithRequestBody(t *testing.T) {
testCases := []struct {
desc string
statusCode []string
handlerStatusCode int
requestBody string
expectedHandler string
}{
{
desc: "request body replayed to fallback handler",
statusCode: []string{"503"},
handlerStatusCode: 503,
requestBody: "test request body",
expectedHandler: "fallback",
},
{
desc: "request body used by main handler",
statusCode: []string{"503"},
handlerStatusCode: 200,
requestBody: "test request body",
expectedHandler: "handler",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
var receivedBody string
maxBody := int64(-1)
failover, err := New(&dynamic.Failover{
Errors: &dynamic.FailoverError{
Status: test.statusCode,
MaxRequestBodyBytes: &maxBody,
},
})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
receivedBody = string(body)
rw.Header().Set("server", "handler")
rw.WriteHeader(test.handlerStatusCode)
}))
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, _ := io.ReadAll(req.Body)
receivedBody = string(body)
rw.Header().Set("server", "fallback")
rw.WriteHeader(http.StatusOK)
}))
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody))
failover.ServeHTTP(recorder, req)
assert.Equal(t, test.expectedHandler, recorder.Header().Get("server"))
assert.Equal(t, test.requestBody, receivedBody)
})
}
}
func TestFailoverStatusCodeMaxBodySize(t *testing.T) {
testCases := []struct {
desc string
maxBodySize int64
requestBody string
expectedStatusCode int
expectedMessage string
}{
{
desc: "request body within limit",
maxBodySize: 100,
requestBody: "small body",
expectedStatusCode: 200,
expectedMessage: "",
},
{
desc: "request body exceeds limit",
maxBodySize: 5,
requestBody: "this body is too large",
expectedStatusCode: http.StatusRequestEntityTooLarge,
expectedMessage: "Request body too large\n",
},
{
desc: "zero body size limit",
maxBodySize: 0,
requestBody: "any size body",
expectedStatusCode: http.StatusRequestEntityTooLarge,
expectedMessage: "Request body too large\n",
},
{
desc: "no body size limit",
maxBodySize: -1,
requestBody: "any size body should work",
expectedStatusCode: 200,
expectedMessage: "",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
maxBody := test.maxBodySize
failover, err := New(&dynamic.Failover{
Errors: &dynamic.FailoverError{
Status: []string{"503"},
MaxRequestBodyBytes: &maxBody,
},
})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusServiceUnavailable)
}))
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(test.requestBody))
req.ContentLength = int64(len(test.requestBody))
failover.ServeHTTP(recorder, req)
assert.Equal(t, test.expectedStatusCode, recorder.Code)
if test.expectedMessage != "" {
assert.Equal(t, test.expectedMessage, recorder.Body.String())
}
})
}
}
func TestFailoverStatusCodeWithHealthCheck(t *testing.T) {
failover, err := New(&dynamic.Failover{
HealthCheck: &dynamic.HealthCheck{},
Errors: &dynamic.FailoverError{
Status: []string{"503"},
},
})
require.NoError(t, err)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "handler")
rw.WriteHeader(http.StatusServiceUnavailable)
}))
failover.SetFallbackHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "fallback")
rw.WriteHeader(http.StatusOK)
}))
// Test 1: Handler is up, returns 503, should failover based on status code
recorder := httptest.NewRecorder()
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, "fallback", recorder.Header().Get("server"))
assert.Equal(t, http.StatusOK, recorder.Code)
// Test 2: Handler is marked down via health check, should failover
failover.SetHandlerStatus(t.Context(), false)
recorder = httptest.NewRecorder()
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, "fallback", recorder.Header().Get("server"))
assert.Equal(t, http.StatusOK, recorder.Code)
// Test 3: Handler is back up but returns non-503 status, should not failover
failover.SetHandlerStatus(t.Context(), true)
failover.SetHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "handler")
rw.WriteHeader(http.StatusOK)
}))
recorder = httptest.NewRecorder()
failover.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, "handler", recorder.Header().Get("server"))
assert.Equal(t, http.StatusOK, recorder.Code)
}
func TestFailoverInvalidStatusCodeRange(t *testing.T) {
_, err := New(&dynamic.Failover{
Errors: &dynamic.FailoverError{
Status: []string{"invalid"},
},
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "strconv")
}

View file

@ -0,0 +1,77 @@
package failover
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/traefik/traefik/v3/pkg/types"
)
type responseWriter struct {
http.ResponseWriter
needFallback bool
written bool
header http.Header
statusCodeRange types.HTTPCodeRanges
}
func (r *responseWriter) Write(b []byte) (int, error) {
if !r.written {
r.WriteHeader(http.StatusOK)
}
if r.needFallback {
// As we will fallback, we can discard the response body.
return len(b), nil
}
return r.ResponseWriter.Write(b)
}
func (r *responseWriter) Header() http.Header {
if r.header == nil {
r.header = make(http.Header)
}
return r.header
}
func (r *responseWriter) WriteHeader(statusCode int) {
if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
clear(r.header)
return
}
if !r.written {
r.written = true
r.needFallback = r.statusCodeRange.Contains(statusCode)
if !r.needFallback {
for k, v := range r.header {
for _, vv := range v {
r.ResponseWriter.Header().Add(k, vv)
}
}
r.ResponseWriter.WriteHeader(statusCode)
}
}
}
func (r *responseWriter) Flush() {
if flusher, ok := r.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := r.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", r.ResponseWriter)
}

View file

@ -54,6 +54,10 @@ type mirrorHandler struct {
count uint64
}
type clonableRequest interface {
Clone(ctx context.Context) *http.Request
}
func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
mirrors := m.getActiveMirrors()
if len(mirrors) == 0 {
@ -62,21 +66,28 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
logger := log.Ctx(req.Context())
rr, bytesRead, err := newReusableRequest(req, m.mirrorBody, m.maxBodySize)
if err != nil && !errors.Is(err, errBodyTooLarge) {
http.Error(rw, fmt.Sprintf("%s: creating reusable request: %v",
http.StatusText(http.StatusInternalServerError), err), http.StatusInternalServerError)
return
var rr clonableRequest = req
if m.mirrorBody {
var err error
var bytesRead []byte
rr, bytesRead, err = NewReusableRequest(req, m.maxBodySize)
if err != nil && !errors.Is(err, ErrBodyTooLarge) {
logger.Debug().Err(err).Msg("Error while creating reusable request for mirroring")
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if errors.Is(err, ErrBodyTooLarge) {
req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
m.handler.ServeHTTP(rw, req)
logger.Debug().Msg("No mirroring, request body larger than allowed size")
return
}
}
if errors.Is(err, errBodyTooLarge) {
req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
m.handler.ServeHTTP(rw, req)
logger.Debug().Msg("No mirroring, request body larger than allowed size")
return
}
m.handler.ServeHTTP(rw, rr.clone(req.Context()))
m.handler.ServeHTTP(rw, rr.Clone(req.Context()))
select {
case <-req.Context().Done():
@ -89,7 +100,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
m.routinePool.GoCtx(func(_ context.Context) {
for _, handler := range mirrors {
// prepare request, update body from buffer
r := rr.clone(req.Context())
r := rr.Clone(req.Context())
// In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context
// to know whether we should mutate said datatable (and contribute some fields to the log).
@ -192,23 +203,23 @@ func (c contextStopPropagation) Done() <-chan struct{} {
return make(chan struct{})
}
// reusableRequest keeps in memory the body of the given request,
// ReusableRequest keeps in memory the body of the given request,
// so that the request can be fully cloned by each mirror.
type reusableRequest struct {
type ReusableRequest struct {
req *http.Request
body []byte
}
var errBodyTooLarge = errors.New("request body too large")
var ErrBodyTooLarge = errors.New("request body too large")
// if the returned error is errBodyTooLarge, newReusableRequest also returns the
// bytes that were already consumed from the request's body.
func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (*reusableRequest, []byte, error) {
// NewReusableRequest returns a new reusable request. If the returned error is ErrBodyTooLarge, NewReusableRequest also returns the
// bytes already consumed from the request's body.
func NewReusableRequest(req *http.Request, maxBodySize int64) (*ReusableRequest, []byte, error) {
if req == nil {
return nil, nil, errors.New("nil input request")
}
if req.Body == nil || req.ContentLength == 0 || !mirrorBody {
return &reusableRequest{req: req}, nil, nil
if req.Body == nil || req.ContentLength == 0 {
return &ReusableRequest{req: req}, nil, nil
}
// unbounded body size
@ -217,7 +228,7 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (
if err != nil {
return nil, nil, err
}
return &reusableRequest{
return &ReusableRequest{
req: req,
body: body,
}, nil, nil
@ -234,17 +245,18 @@ func newReusableRequest(req *http.Request, mirrorBody bool, maxBodySize int64) (
// we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read,
// which permits us sending also to all the mirrors later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return &reusableRequest{
return &ReusableRequest{
req: req,
body: body[:n],
}, nil, nil
}
// err == nil , which means data size > maxBodySize
return nil, body[:n], errBodyTooLarge
return nil, body[:n], ErrBodyTooLarge
}
func (rr reusableRequest) clone(ctx context.Context) *http.Request {
// Clone clones the request.
func (rr ReusableRequest) Clone(ctx context.Context) *http.Request {
req := rr.req.Clone(ctx)
if rr.body != nil {

View file

@ -224,16 +224,16 @@ func TestCloneRequest(t *testing.T) {
assert.NoError(t, err)
ctx := req.Context()
rr, _, err := newReusableRequest(req, true, defaultMaxBodySize)
rr, _, err := NewReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
cloned := rr.Clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
// second call
cloned = rr.clone(ctx)
cloned = rr.Clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
})
@ -249,17 +249,17 @@ func TestCloneRequest(t *testing.T) {
ctx := req.Context()
req.ContentLength = int64(contentLength)
rr, _, err := newReusableRequest(req, true, defaultMaxBodySize)
rr, _, err := NewReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
cloned := rr.Clone(ctx)
body, err := io.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
// second call
cloned = rr.clone(ctx)
cloned = rr.Clone(ctx)
body, err = io.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
@ -272,7 +272,7 @@ func TestCloneRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, true, 2)
_, expectedBytes, err := NewReusableRequest(req, 2)
assert.Error(t, err)
assert.Equal(t, expectedBytes, bb[:3])
})
@ -284,7 +284,7 @@ func TestCloneRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
rr, expectedBytes, err := newReusableRequest(req, true, 20)
rr, expectedBytes, err := NewReusableRequest(req, 20)
assert.NoError(t, err)
assert.Nil(t, expectedBytes)
assert.Len(t, rr.body, 10)
@ -296,14 +296,14 @@ func TestCloneRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", buf)
assert.NoError(t, err)
rr, expectedBytes, err := newReusableRequest(req, true, 20)
rr, expectedBytes, err := NewReusableRequest(req, 20)
assert.NoError(t, err)
assert.Nil(t, expectedBytes)
assert.Empty(t, rr.body)
})
t.Run("no request given", func(t *testing.T) {
_, _, err := newReusableRequest(nil, true, defaultMaxBodySize)
_, _, err := NewReusableRequest(nil, defaultMaxBodySize)
assert.Error(t, err)
})
}

View file

@ -210,24 +210,29 @@ func (m *Manager) LaunchHealthCheck(ctx context.Context) {
}
func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName string, config *dynamic.Failover) (http.Handler, error) {
f := failover.New(config.HealthCheck)
serviceHandler, err := m.BuildHTTP(ctx, config.Service)
if err != nil {
return nil, err
}
f.SetHandler(serviceHandler)
updater, ok := serviceHandler.(healthcheck.StatusUpdater)
if !ok {
updater, implementUpdater := serviceHandler.(healthcheck.StatusUpdater)
isErrorDefined := config.Errors != nil && len(config.Errors.Status) > 0
if !implementUpdater && !isErrorDefined {
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Service, serviceName, serviceHandler)
}
if err := updater.RegisterStatusUpdater(func(up bool) {
f.SetHandlerStatus(ctx, up)
}); err != nil {
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err)
f, err := failover.New(config)
if err != nil {
return nil, fmt.Errorf("error creating failover service %v: %w", serviceName, err)
}
f.SetHandler(serviceHandler)
if implementUpdater {
if err := updater.RegisterStatusUpdater(func(up bool) {
f.SetHandlerStatus(ctx, up)
}); err != nil && !isErrorDefined {
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", config.Service, serviceName, err)
}
}
fallbackHandler, err := m.BuildHTTP(ctx, config.Fallback)
@ -242,8 +247,8 @@ func (m *Manager) getFailoverServiceHandler(ctx context.Context, serviceName str
return f, nil
}
fallbackUpdater, ok := fallbackHandler.(healthcheck.StatusUpdater)
if !ok {
fallbackUpdater, implementUpdater := fallbackHandler.(healthcheck.StatusUpdater)
if !implementUpdater {
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", config.Fallback, serviceName, fallbackHandler)
}